Degree project: predicting patient outcome with RNA-seq data


The aims of this project:

  • To predict patient outcomes based on gene and transcript expression.

  • To explore and improve methods for feature selection which extract the informative genes/transcripts associated with breast cancer relapse.

To select informative features associated with breast cancer patient outcomes, a customized ensemble feature selection model based on the penalized Cox proportional hazards (PH) model was developed and compared with a univariate method and a single-run Cox PH model. This method can be used to select features from unbalanced, ultra-high-dimensional, time-to-event data.

Introduction

Breast cancer

https://gco.iarc.fr/today/online-analysis-pie

Breast cancer has overtaken lung cancer to become the most common cancer worldwide. In 2020, breast cancer caused 2.26 million new cases and 685 000 deaths, making it the fifth leading cause of cancer death.

SCAN-B

The Sweden Cancerome Analysis Network–Breast (SCAN-B) Initiative was initially proposed in 2009 and began enrolling patients in 2010. Within SCAN-B, breast tumors from hospitals across a wide geography of Sweden are routinely processed and RNA-sequenced. To date, more than 17 thousand patients have been enrolled in SCAN-B and more than 11 thousand tumors have been RNA-sequenced, providing a rich resource for breast cancer research. To our knowledge, SCAN-B is the largest RNA-seq breast cancer study in the world.

Survival analysis and time to event data

Survival analysis is a type of regression problem that tries to establish a connection between covariates and the time to an event. The start point is usually the date of primary treatment, but in SCAN-B it is the date of diagnosis. Depending on the specific question, the endpoint can be overall survival (OS), which includes all death events; relapse-free survival (RFS), which includes death and relapse events; or recurrence-free interval (RFi), which includes recurrence events only.

A common issue in survival analysis is that the data is censored.

https://scikit-survival.readthedocs.io/en/stable/user_guide/understanding_predictions.html

The events for patients B and D are recorded, but for patients A, C, and E the only available information is that they were event-free up to their last follow-up, so they are censored.
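In the scripts below, such censored outcomes are encoded as a structured NumPy array holding an event indicator and a follow-up time, which is the format expected by scikit-survival. A minimal sketch with made-up follow-up times:

# Encode censored time-to-event labels for scikit-survival (illustrative values only)
import numpy as np

events = [False, True, False, True, False]      # A, C, E are censored; B, D had an event
times = [1825.0, 730.0, 1460.0, 365.0, 2190.0]  # days to event or to last follow-up

os_type = {'names': ('events', 'time'), 'formats': ('?', '<f8')}
y = np.array(list(zip(events, times)), dtype=os_type)
print(y['events'], y['time'])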

MATERIALS AND USAGE

Materials

The breast cancer data were obtained from SCAN-B and comprise 2874 ER+/HER2- samples, which are used to predict patient outcomes.

                  RFi                    RFS
                  Train    Validation    Train    Validation
With event        116      27            342      95
Without event     1139     339           1957     480
Censoring ratio   90.8%    92.6%         85.1%    83.5%
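The censoring ratio in the table is simply the share of samples without an event, e.g. for the RFi training set:

# Censoring ratio = samples without an event / all samples (RFi training set above)
with_event, without_event = 116, 1139
print(f'{without_event / (with_event + without_event):.1%}')  # -> 90.8%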

Usage

All the scripts are stored in scripts/ and managed with Snakemake. It is recommended to run the project on a cluster.

How to start:

# A conda environment is available in '/env/ballgown'
# Unpack environment into directory `ballgown`
cd env
mkdir -p ballgown
tar -xzf ballgown.tar.gz -C ballgown
# Activate the environment.
source ballgown/bin/activate
# Cluster execution
snakemake --profile lsens

Workflow design

  • de novo assembly pipeline

[Figure: de novo assembly pipeline]

A1: The StringTie expression estimation mode (-e). A reference annotation file is required, and the main output is a gene abundances file. With this option, no "novel" transcript assemblies (isoforms) are produced, and read alignments that do not overlap any of the given reference transcripts are ignored.

A2: The StringTie merge function. Without the -e option, StringTie assembles transcripts for each sample and all reads are taken into consideration. The merge function is then used to generate a non-redundant GTF file that contains the de novo transcript information. Next, the expression of the novel transcripts is estimated with the de novo assembly annotation file, and scripts/extract.py is used to extract the gene and transcript FPKM values into a specific format.

  • Workflow design

[Figure: feature selection and prediction workflow]

Univariate method: In scripts/cox_unitest.py, CoxPHFitter() is used to test the features one by one, and the significant features are selected based on the adjusted p-value (scripts/FDR.R).
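A minimal sketch of this step; the actual implementation is in scripts/cox_unitest.py and scripts/FDR.R, and here the FDR adjustment is done with statsmodels instead of R, with column names as illustrative defaults:

# Univariate Cox PH test per feature, followed by FDR (Benjamini-Hochberg) adjustment
import pandas as pd
from lifelines import CoxPHFitter
from statsmodels.stats.multitest import multipletests

def univariate_pvalues(df, features, days='RFi_days', event='RFi_event'):
    """Fit one Cox model per feature and collect the p-value of that feature."""
    pvals = {}
    for f in features:
        cph = CoxPHFitter().fit(df[[f, days, event]], duration_col=days, event_col=event)
        pvals[f] = cph.summary.loc[f, 'p']
    return pd.Series(pvals)

def fdr_select(pvals, alpha=0.05):
    """Keep the features whose FDR-adjusted p-value stays below alpha."""
    reject, p_adj, _, _ = multipletests(pvals.values, alpha=alpha, method='fdr_bh')
    return list(pvals.index[reject])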

One-run analysis: In scripts/one_run.py, CoxnetSurvivalAnalysis() is used to select features in a single run. The LASSO ratio of the elastic net is set to 0.9.
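A condensed sketch of the one-run selection, mirroring the grid search in the one_run.py snippet further down (fold count and alpha path taken from that script):

# Single-run elastic-net Cox selection: l1_ratio=0.9 means roughly 90% LASSO / 10% ridge
import pandas as pd
from sksurv.linear_model import CoxnetSurvivalAnalysis
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.pipeline import make_pipeline

def one_run_selection(X, y):
    # Let Coxnet propose an alpha path, then pick the best alpha by cross-validation
    path = CoxnetSurvivalAnalysis(l1_ratio=0.9, alpha_min_ratio=0.01, n_alphas=20).fit(X, y)
    grid = GridSearchCV(
        make_pipeline(CoxnetSurvivalAnalysis(l1_ratio=0.9)),
        param_grid={'coxnetsurvivalanalysis__alphas': [[a] for a in path.alphas_]},
        cv=KFold(n_splits=5, shuffle=True, random_state=0),
        error_score=0.5,
    ).fit(X, y)
    best = grid.best_estimator_.named_steps['coxnetsurvivalanalysis']
    coefs = pd.Series(best.coef_.ravel(), index=X.columns)
    return coefs[coefs != 0]  # non-zero coefficients = selected features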

Ensemble method: In scripts/pre_selection.py, a customized selection model using the elastic-net penalized Cox model as the base learner is designed to process ultra-high-dimensional data.

The selected feature lists are named label-best.feature, where label is the endpoint (e.g. RFi).

Ensemble method

[Figure: ensemble feature selection scheme]

When processing ultra-high-dimensional data, the high-LASSO-ratio elastic net becomes unstable, which is why we use an ensemble approach to avoid this issue. First, the features are randomly split into m groups, where m is based on the number of features n divided by the number of samples p. This roughly "square" data is then passed to the base learner, an elastic net consisting of 90% LASSO and 10% ridge. Each base learner carries out a selection adjusted for age and treatment; once the results from all base learners are combined, one round of selection is finished. To cover as many random feature groupings as possible, the selection is repeated k times. After collecting the selected features from all k iterations, the selection frequency of each feature is counted. A cutoff, defined as a fraction of the maximum frequency, is treated as a hyper-parameter and can be tuned to improve the performance of the model.
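A condensed sketch of one form this can take; the full implementation (including the age and treatment adjustment and the cross-validation folds described below) is in scripts/pre_selection.py, and the fixed alpha and cutoff here are illustrative:

# Ensemble selection: split features into ~square groups, fit an elastic-net Cox model per
# group, repeat k times, and keep features selected more often than a frequency cutoff.
import numpy as np
import pandas as pd
from sksurv.linear_model import CoxnetSurvivalAnalysis

def ensemble_selection(X, y, k=10, alpha=0.05, cutoff_ratio=0.1, random_state=0):
    rng = np.random.RandomState(random_state)
    m = max(X.shape[1] // X.shape[0], 1)           # number of feature groups per round
    counts = pd.Series(0, index=X.columns)
    for _ in range(k):
        for group in np.array_split(rng.permutation(X.columns), m):
            model = CoxnetSurvivalAnalysis(l1_ratio=0.9, alphas=[alpha]).fit(X[group], y)
            nonzero = np.ravel(model.coef_) != 0
            counts.loc[group[nonzero]] += 1
        # one round finished: every feature was offered to exactly one base learner
    cutoff = counts.max() * cutoff_ratio           # fraction of the maximum frequency
    return counts[counts > cutoff].index.tolist()

In scripts/pre_selection.py the alpha is tuned by cross-validation inside each base learner rather than fixed.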

[Figure: balanced fold construction and undersampling]

To prevent overfitting, the samples with and without events are each split into five folds using a globally fixed random seed, so that the base learners are comparable. To deal with the unbalanced data, no-event samples in each fold are randomly excluded until the fold contains the same number of samples as the corresponding event fold. This random undersampling is performed anew for each base learner, so that all no-event samples are eventually used.
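A sketch of the fold construction under these rules (the fold seeds match those used in scripts/pre_selection.py; the real script additionally filters no-event samples by follow-up time):

# Build 5 comparable folds: event and no-event samples are split with fixed seeds, then the
# no-event part of each fold is randomly undersampled to the size of the event part.
import numpy as np
from sklearn.model_selection import KFold

def balanced_folds(has_event, n_splits=5, rng=None):
    rng = rng or np.random.RandomState()
    has_event = np.asarray(has_event, dtype=bool)
    idx_true, idx_false = np.where(has_event)[0], np.where(~has_event)[0]
    folds_true = [idx_true[f] for _, f in
                  KFold(n_splits, shuffle=True, random_state=1723).split(idx_true)]
    folds_false = [idx_false[f] for _, f in
                   KFold(n_splits, shuffle=True, random_state=1938).split(idx_false)]
    folds = []
    for ft, ff in zip(folds_true, folds_false):
        ff = rng.permutation(ff)[:len(ft)]   # undersampling is re-drawn for every base learner
        folds.append(np.concatenate([ft, ff]))
    return folds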

RESULTS

RFi (recurrence-free interval) prediction

[Figure: time-dependent AUC for RFi prediction]

All feature selection methods perform better than the baseline, especially the two red lines (de novo gene set) and the blue line (transcript set).

RFS (relapse-free survival) prediction

[Figure: time-dependent AUC for RFS prediction]

Performance on the gene set and the transcript set is similar. The feature sets selected by the univariate method (the green lines) are considerably better than the baseline and the other methods.

Conclusion

  • It is possible to use transcript expression levels to predict patient outcome, but collinearity is an issue

  • The univariate method is a good tool for selecting features and can effectively reduce redundancy

  • For survival analysis, good label quality is the key to a successful prediction

  • When processing ultra-high-dimensional data, the one-run method is sometimes unstable and cannot select features efficiently

  • This project forms a basis for further work on the development of predictive signatures for breast cancer patient outcome

Code Snippets

# Snakemake script: pick the feature list from the best of several selection rounds
import os
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

path = os.path.split(snakemake.output[0])[0]
label = os.path.split(snakemake.output[0])[1].split('-')[0]

N_round = [i for i in range(1,10)]
dataset = [path for i in range(9)]
performance = []
features = []

for i in N_round:
    with open(os.path.join(path, label+f'-describe.Round{i}')) as f:
        performance.append(float(f.readline().split(',')[0][6:]))
    features.append(pd.read_csv(os.path.join(path, label+f'-features.Round{i}')).shape[0])

pd.read_csv(os.path.join(path, label+f'-features.Round{performance.index(max(performance))+1}')).to_csv(snakemake.output[0], index=False)
pd.DataFrame({'dataset':dataset, 'N_round':N_round, 'performance':performance, 'features':features}).to_csv(snakemake.output[1], index=False)
# Snakemake script: concatenate four input CSV files into one table
import pandas as pd

pd.concat([pd.read_csv(snakemake.input[i]) for i in range(4)]).to_csv(snakemake.output[0], index=False)
# Snakemake script: fit a penalized Cox model on the selected features and report
# time-dependent AUC against an age + treatment baseline
import pandas as pd
import numpy as np
import os
from sksurv.linear_model import CoxnetSurvivalAnalysis
from sksurv.preprocessing import OneHotEncoder
from sksurv.metrics import (
    as_concordance_index_ipcw_scorer,
    as_cumulative_dynamic_auc_scorer,
    as_integrated_brier_score_scorer,
    cumulative_dynamic_auc
)
from sklearn.model_selection import GridSearchCV, KFold, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, MinMaxScaler


event = snakemake.params[0]
days = snakemake.params[1]


def load_data(dataP, spInfo, features):
    data = pd.read_csv(dataP).set_index('ID')
    sample_info = pd.read_csv(spInfo)
    data.drop(['MEAN','VAR'], axis=1, inplace=True)
    data = data.T.reset_index().rename(columns={'index':'SAMPLE'})
    data.columns.names = [None]
    data = data[['SAMPLE']+features]
    if 'RFi' in event:
        data = pd.merge(sample_info[['SAMPLE', 'RFi_days', 'RFi_event', 'RFS_all_event', 'age', 'treatGroup']], data, on='SAMPLE')
        data = data[data.RFi_days.notna()].reset_index(drop=True)
        data = data[data.apply(lambda x: False if (x['RFi_event'] == 0 and x['RFi_days'] < 600) or (x['RFi_event'] == False and x['RFS_all_event'] == True) else True, axis = 1)].reset_index(drop=True)
        data.drop(['RFS_all_event'], axis=1, inplace=True)
    else:
        data = pd.merge(sample_info[['SAMPLE', days, event, 'age', 'treatGroup']], data, on='SAMPLE')
        data = data[data[event].notna()].reset_index(drop=True)
        data = data[data[days].notna()].reset_index(drop=True)
    x, y = data.iloc[::, 3:], data.iloc[::, 1:3]
    treatment = x.treatGroup.astype('category')
    x.drop('treatGroup', axis=1, inplace=True)
    x = pd.DataFrame(StandardScaler().fit_transform(x), columns=x.columns).dropna(axis=1)
    x['treatGroup'] = treatment
    x = OneHotEncoder().fit_transform(x).replace(np.nan, 0)
    return x, y

output = snakemake.output[0]
method = output.split('_')[-2]

if method == 'uni':
    features = pd.read_csv(os.path.join(os.path.dirname(output), f'{event.split("_")[0]}_uni_fdr.csv'))
    features = features[features['p.adjust']<0.05]
    features = features['fetures'].to_list()
elif method == 'ens':
    features = pd.read_csv(os.path.join(os.path.dirname(output), f'{event.split("_")[0]}-best.feature'))
    features = features['Feature'].to_list()
elif method == '1run':
    which = output.split('/')[0]
    dataset = lambda x: 'genes' if 'gs' in x else 'transcripts'
    features = pd.read_csv(os.path.join(os.path.dirname(output), f'{event.split("_")[0]}_1run_{which}_{dataset(output.split("/")[1])}_best_coefs.csv'), index_col=0).reset_index()
    features = features['index'].to_list()
    features = [x for x in features if x != 'age' and not x.startswith('treat')]
else:
    features = []

x, y = load_data(snakemake.input[0], snakemake.input[1], features)
os_type = {'names':('events', 'time'), 'formats':('?', '<f8')}
y = np.array([(a, b) for a, b in zip(y[event].astype(bool), y[days])], dtype=os_type)

xtrain, xval, ytrain, yval = train_test_split(x, y, test_size=0.2, random_state=954)

lower, upper = np.percentile(pd.DataFrame(ytrain).time, [25, 75])
grids_times = np.arange(lower, upper + 1)

cv = KFold(n_splits=3, shuffle=True, random_state=0)
if method == 'uni':
    alphas = 10. ** np.linspace(-1.8, -1, 5)
else:
    alphas = 10. ** np.linspace(-1.9, -1.5, 5)

grids_cindex = GridSearchCV(
    as_concordance_index_ipcw_scorer(CoxnetSurvivalAnalysis(l1_ratio=0.9, tol=1e-15, max_iter=10000), tau=grids_times[-1]),
    param_grid={"estimator__alphas": [[v] for v in alphas]},
    cv=cv,
    error_score=0.5,
    n_jobs=16).fit(xtrain, ytrain)

cph = CoxnetSurvivalAnalysis(l1_ratio=0.9, alphas=[grids_cindex.best_params_["estimator__alphas"]])
cph.fit(xtrain, ytrain)

best_coefs = pd.DataFrame(cph.coef_, index=xtrain.columns, columns=['coefficient'])
non_zero = np.sum(best_coefs.iloc[:, 0] != 0)
non_zero_coefs = best_coefs.query("coefficient != 0")

va_times = np.arange(np.percentile(pd.DataFrame(yval).time, [5, 95])[0], np.percentile(pd.DataFrame(yval).time, [5, 95])[1], 60)
cph_risk_scores = cph.predict(xval)
auc, mean_auc = cumulative_dynamic_auc(ytrain, yval, cph_risk_scores, va_times)



baseline_feature = [x for x in xtrain.columns if x == 'age' or x.startswith('treat')]
base = CoxnetSurvivalAnalysis(l1_ratio=0.9, tol=1e-20, max_iter=10000, alphas=[0])
base.fit(xtrain[baseline_feature], ytrain)
base_risk_scores = base.predict(xval[baseline_feature])
base_auc, base_mean = cumulative_dynamic_auc(ytrain, yval, base_risk_scores, va_times)


with open(output, 'w') as fo:
    fo.write(f'## training \n')
    fo.write(f'# mean cindex: {grids_cindex.cv_results_["mean_test_score"]} \n')
    fo.write(f'# best alpha: {grids_cindex.best_params_["estimator__alphas"]} \n')
    fo.write(f'## validation \n')
    fo.write(f'# cindex:{cph.score(xval, yval)} \n')
    fo.write(f'# features in: {len(features)}+age+treatment, features out: {non_zero} \n')
    fo.write(f'## baseline \n')
    fo.write(f'## features: {baseline_feature} \n')
    fo.write(f'# cindex: {base.score(xval[baseline_feature], yval)} \n')
    fo.write(f'# base_mean_auc: {base_mean} \n')


pd.DataFrame({'AUC':auc, 'base':base_auc, 'time':va_times, 'mean':[mean_auc]*len(auc)}).to_csv(output, index=False, mode='a')
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split

sample_info = pd.read_csv(snakemake.input[-1])

fo = snakemake.output[0]

event = snakemake.params[0]
days = snakemake.params[1]

label = event.split("_")[0].lower()

de_gs_uni = pd.read_csv(f'denovo/{label}_age_trm_gs/cox_{event.split("_")[0]}_uni_analysis.txt', comment='#')
de_gs_1run = pd.read_csv(f'denovo/{label}_age_trm_gs/cox_{event.split("_")[0]}_1run_analysis.txt', comment='#')
de_gs_ens = pd.read_csv(f'denovo/{label}_age_trm_gs/cox_{event.split("_")[0]}_ens_analysis.txt', comment='#')

v32_gs_uni = pd.read_csv(f'normal/{label}_age_trm_gs/cox_{event.split("_")[0]}_uni_analysis.txt', comment='#')
v32_gs_1run = pd.read_csv(f'normal/{label}_age_trm_gs/cox_{event.split("_")[0]}_1run_analysis.txt', comment='#')
v32_gs_ens = pd.read_csv(f'normal/{label}_age_trm_gs/cox_{event.split("_")[0]}_ens_analysis.txt', comment='#')

v27_gs_uni = pd.read_csv(f'v27/{label}_age_trm_gs/cox_{event.split("_")[0]}_uni_analysis.txt', comment='#')
v27_gs_1run = pd.read_csv(f'v27/{label}_age_trm_gs/cox_{event.split("_")[0]}_1run_analysis.txt', comment='#')
v27_gs_ens = pd.read_csv(f'v27/{label}_age_trm_gs/cox_{event.split("_")[0]}_ens_analysis.txt', comment='#')

de_ts_uni = pd.read_csv(f'denovo/{label}_age_trm_ts/cox_{event.split("_")[0]}_uni_analysis.txt', comment='#')
de_ts_1run = pd.read_csv(f'denovo/{label}_age_trm_ts/cox_{event.split("_")[0]}_1run_analysis.txt', comment='#')
de_ts_ens = pd.read_csv(f'denovo/{label}_age_trm_ts/cox_{event.split("_")[0]}_ens_analysis.txt', comment='#')

v32_ts_uni = pd.read_csv(f'normal/{label}_age_trm_ts/cox_{event.split("_")[0]}_uni_analysis.txt', comment='#')
v32_ts_1run = pd.read_csv(f'normal/{label}_age_trm_ts/cox_{event.split("_")[0]}_1run_analysis.txt', comment='#')
v32_ts_ens = pd.read_csv(f'normal/{label}_age_trm_ts/cox_{event.split("_")[0]}_ens_analysis.txt', comment='#')


if 'RFi' in event:
    sample_info = sample_info[['RFi_days', 'RFi_event', 'RFS_all_event']]
    sample_info = sample_info[sample_info.RFi_days.notna()].reset_index(drop=True)
    sample_info = sample_info[sample_info.apply(lambda x: False if (x['RFi_event'] == 0 and x['RFi_days'] < 600) or (x['RFi_event'] == False and x['RFS_all_event'] == True) else True, axis = 1)].reset_index(drop=True)
    sample_info.drop(['RFS_all_event'], axis=1, inplace=True)
else:
    sample_info = sample_info[[days, event]]
    sample_info = sample_info[sample_info[event].notna()].reset_index(drop=True)
    sample_info = sample_info[sample_info[days].notna()].reset_index(drop=True)

_, test = train_test_split(sample_info, test_size=0.2, random_state=954)
va_times = np.arange(np.percentile(sample_info[days], [5, 95])[0], np.percentile(sample_info[days], [5, 95])[1], 60)

os_type = {'names':('events', 'time'), 'formats':('?', '<f8')}
test = np.array([(a, b) for a, b in zip(test[event].astype(bool), test[days])], dtype=os_type)

fg = plt.figure(figsize=(18, 7), dpi=128)
ax = fg.add_subplot(111)

ax.plot(de_gs_uni.time, de_gs_uni.AUC, marker=".", color='lightcoral',label='de_gs_uni')
ax.plot(de_gs_1run.time, de_gs_1run.AUC, marker=".", color='firebrick',label='de_gs_1run')
ax.plot(de_gs_ens.time, de_gs_ens.AUC, marker=".", color='darkred', label='de_gs_ens')

ax.plot(v32_gs_uni.time, v32_gs_uni.AUC, marker=".", color='chocolate', label='v32_gs_uni')
ax.plot(v32_gs_1run.time, v32_gs_1run.AUC, marker=".", color='peru', label='v32_gs_1run')
ax.plot(v32_gs_ens.time, v32_gs_ens.AUC, marker=".", color='darkorange',label='v32_gs_ens')

ax.plot(v27_gs_uni.time, v27_gs_uni.AUC, marker=".", color='darkseagreen', label='v27_gs_uni')
ax.plot(v27_gs_1run.time, v27_gs_1run.AUC, marker=".", color='forestgreen', label='v27_gs_1run')
ax.plot(v27_gs_ens.time, v27_gs_ens.AUC, marker=".", color='seagreen', label='v27_gs_ens')

ax.plot(de_ts_uni.time, de_ts_uni.AUC, marker=".", color='teal', label='de_ts_uni')
ax.plot(de_ts_1run.time, de_ts_1run.AUC, marker=".", color='skyblue', label='de_ts_1run')
ax.plot(de_ts_ens.time, de_ts_ens.AUC, marker=".", color='steelblue', label='de_ts_ens')

ax.plot(v32_ts_uni.time, v32_ts_uni.AUC, marker=".", color='navy',label='v32_ts_uni')
ax.plot(v32_ts_1run.time, v32_ts_1run.AUC, marker=".", color='slateblue', label='v32_ts_1run')
ax.plot(v32_ts_ens.time, v32_ts_ens.AUC, marker=".", color='purple', label='v32_ts_ens')

ax.plot(v32_ts_ens.time, v32_ts_ens.base, marker=".", color='silver', label='Baseline')

plt.xticks(fontsize=15)
plt.yticks(fontsize=15)

ax.set_xticklabels([])
ax.set_ylabel("Time-Dependent AUC", fontsize=18)
plt.legend(bbox_to_anchor=(1, 1), loc=2, borderaxespad=0, fontsize=12)


cum_events = []
cum_no_events = []
for i_time in va_times:
    i_events = len([i_event for i_event in test if (i_event[0] and i_event[1] <= i_time)])
    i_no_events = len([i_event for i_event in test if (i_event[1] > i_time)])
    cum_events.append(i_events)
    cum_no_events.append(i_no_events)

plt.table(cellText=[cum_events, cum_no_events],
          rowLabels=['Events (Cumulative)', 'No Events (-Cumulative)'],
          colLabels=va_times,
          cellLoc='center',
          loc='bottom')


plt.savefig(fo)
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split

sample_info = pd.read_csv(snakemake.input[-1])

fo = snakemake.output[0]

event = snakemake.params[0]
days = snakemake.params[1]

if event == 'Ki67_censored_event':
    label = 'ki67_censored'
else:
    label = 'ki67_uncensored'

de_gs_uni = pd.read_csv(f'denovo/{label}_gs/cox_{event.split("_")[0]}_uni_analysis.txt', comment='#')
de_gs_1run = pd.read_csv(f'denovo/{label}_gs/cox_{event.split("_")[0]}_1run_analysis.txt', comment='#')
de_gs_ens = pd.read_csv(f'denovo/{label}_gs/cox_{event.split("_")[0]}_ens_analysis.txt', comment='#')

v32_gs_uni = pd.read_csv(f'normal/{label}_gs/cox_{event.split("_")[0]}_uni_analysis.txt', comment='#')
v32_gs_1run = pd.read_csv(f'normal/{label}_gs/cox_{event.split("_")[0]}_1run_analysis.txt', comment='#')
v32_gs_ens = pd.read_csv(f'normal/{label}_gs/cox_{event.split("_")[0]}_ens_analysis.txt', comment='#')

de_ts_uni = pd.read_csv(f'denovo/{label}_ts/cox_{event.split("_")[0]}_uni_analysis.txt', comment='#')
de_ts_1run = pd.read_csv(f'denovo/{label}_ts/cox_{event.split("_")[0]}_1run_analysis.txt', comment='#')
de_ts_ens = pd.read_csv(f'denovo/{label}_ts/cox_{event.split("_")[0]}_ens_analysis.txt', comment='#')

v32_ts_uni = pd.read_csv(f'normal/{label}_ts/cox_{event.split("_")[0]}_uni_analysis.txt', comment='#')
v32_ts_1run = pd.read_csv(f'normal/{label}_ts/cox_{event.split("_")[0]}_1run_analysis.txt', comment='#')
v32_ts_ens = pd.read_csv(f'normal/{label}_ts/cox_{event.split("_")[0]}_ens_analysis.txt', comment='#')

if 'RFi' in event:
    sample_info = sample_info[['RFi_days', 'RFi_event', 'RFS_all_event']]
    sample_info = sample_info[sample_info.RFi_days.notna()].reset_index(drop=True)
    sample_info = sample_info[sample_info.apply(lambda x: False if (x['RFi_event'] == 0 and x['RFi_days'] < 600) or (x['RFi_event'] == False and x['RFS_all_event'] == True) else True, axis = 1)].reset_index(drop=True)
    sample_info.drop(['RFS_all_event'], axis=1, inplace=True)
else:
    sample_info = sample_info[[days, event]]
    sample_info = sample_info[sample_info[event].notna()].reset_index(drop=True)
    sample_info = sample_info[sample_info[days].notna()].reset_index(drop=True)

_, test = train_test_split(sample_info, test_size=0.2, random_state=954)
va_times = np.arange(np.percentile(sample_info[days], [5, 95])[0], np.percentile(sample_info[days], [5, 95])[1], 60)

os_type = {'names':('events', 'time'), 'formats':('?', '<f8')}
test = np.array([(a, b) for a, b in zip(test[event].astype(bool), test[days])], dtype=os_type)

fg = plt.figure(figsize=(18, 7), dpi=128)
ax = fg.add_subplot(111)

ax.plot(de_gs_uni.time, de_gs_uni.AUC, marker=".", color='lightcoral',label='de_gs_uni')
ax.plot(de_gs_1run.time, de_gs_1run.AUC, marker=".", color='firebrick',label='de_gs_1run')
ax.plot(de_gs_ens.time, de_gs_ens.AUC, marker=".", color='darkred', label='de_gs_ens')

ax.plot(v32_gs_uni.time, v32_gs_uni.AUC, marker=".", color='chocolate', label='v32_gs_uni')
ax.plot(v32_gs_1run.time, v32_gs_1run.AUC, marker=".", color='peru', label='v32_gs_1run')
ax.plot(v32_gs_ens.time, v32_gs_ens.AUC, marker=".", color='darkorange',label='v32_gs_ens')

ax.plot(de_ts_uni.time, de_ts_uni.AUC, marker=".", color='teal', label='de_ts_uni')
ax.plot(de_ts_1run.time, de_ts_1run.AUC, marker=".", color='skyblue', label='de_ts_1run')
ax.plot(de_ts_ens.time, de_ts_ens.AUC, marker=".", color='steelblue', label='de_ts_ens')

ax.plot(v32_ts_uni.time, v32_ts_uni.AUC, marker=".", color='navy',label='v32_ts_uni')
ax.plot(v32_ts_1run.time, v32_ts_1run.AUC, marker=".", color='slateblue', label='v32_ts_1run')
ax.plot(v32_ts_ens.time, v32_ts_ens.AUC, marker=".", color='purple', label='v32_ts_ens')

ax.plot(v32_ts_ens.time, v32_ts_ens.base, marker=".", color='silver', label='Baseline')

plt.xticks(fontsize=15)
plt.yticks(fontsize=15)

ax.set_xticklabels([])
ax.set_ylabel("Time-Dependent AUC", fontsize=18)
plt.legend(bbox_to_anchor=(1, 1), loc=2, borderaxespad=0, fontsize=12)


cum_events = []
cum_no_events = []
for i_time in va_times:
    i_events = len([i_event for i_event in test if (i_event[0] and i_event[1] <= i_time)])
    i_no_events = len([i_event for i_event in test if (i_event[1] > i_time)])
    cum_events.append(i_events)
    cum_no_events.append(i_no_events)

plt.table(cellText=[cum_events, cum_no_events],
          rowLabels=['Events (Cumulative)', 'No Events (-Cumulative)'],
          colLabels=va_times,
          cellLoc='center',
          loc='bottom')


plt.savefig(fo)
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split


uni = pd.read_csv(snakemake.input[0], comment='#')
onerun = pd.read_csv(snakemake.input[1], comment='#')
ens = pd.read_csv(snakemake.input[2], comment='#')
sample_info = pd.read_csv(snakemake.input[3])

fo = snakemake.output[0] 

event = snakemake.params[0]
days = snakemake.params[1]

with open(snakemake.input[0]) as fi:
    for l in fi:
        if l.startswith('# base_mean_auc:'):
            base_mean = float(l.split(':')[1].strip())

label = os.path.basename(fo).split('-')

if 'gs' in label[1]:
    data_set = 'gene'
else:
    data_set = 'transcript'

if 'normal' in label[0]:
    version = 'V32'
elif 'denovo' in label[0]:
    version = 'de-novo'
else:
    version = 'V27'

if 'ki67_censored' in label[1]:
    label = 'Ki67_censored'
elif 'ki67_uncensored' in label[1]:
    label = 'Ki67_uncensored'
else:
    label = event.split("_")[0]


if 'RFi' in event:
    sample_info = sample_info[['RFi_days', 'RFi_event', 'RFS_all_event']]
    sample_info = sample_info[sample_info.RFi_days.notna()].reset_index(drop=True)
    sample_info = sample_info[sample_info.apply(lambda x: False if (x['RFi_event'] == 0 and x['RFi_days'] < 600) or (x['RFi_event'] == False and x['RFS_all_event'] == True) else True, axis = 1)].reset_index(drop=True)
    sample_info.drop(['RFS_all_event'], axis=1, inplace=True)
else:
    sample_info = sample_info[[days, event]]
    sample_info = sample_info[sample_info[event].notna()].reset_index(drop=True)
    sample_info = sample_info[sample_info[days].notna()].reset_index(drop=True)

_, test = train_test_split(sample_info, test_size=0.2, random_state=954)
va_times = np.arange(np.percentile(sample_info[days], [5, 95])[0], np.percentile(sample_info[days], [5, 95])[1], 60)

os_type = {'names':('events', 'time'), 'formats':('?', '<f8')}
test = np.array([(a, b) for a, b in zip(test[event].astype(bool), test[days])], dtype=os_type)



fg = plt.figure(figsize=(15, 5), dpi=128)
ax = fg.add_subplot(111)

ax.plot(uni.time, uni.AUC, marker=".", color='darkorange', label='Univariate analysis')
ax.axhline(uni['mean'][0], linestyle="--", color='bisque', label='Accumulated AUC = %.3f'%uni['mean'][0])

ax.plot(onerun.time, onerun.AUC, marker=".", color='dodgerblue', label='One run analysis')
ax.axhline(onerun['mean'][0], linestyle="--", color='lightskyblue', label='Accumulated AUC = %.3f'%onerun['mean'][0])

ax.plot(ens.time, ens.AUC, marker=".", color='forestgreen', label='Ensemble method')
ax.axhline(ens['mean'][0], linestyle="--", color='palegreen', label='Accumulated AUC = %.3f'%ens['mean'][0])

ax.plot(ens.time, ens.base, marker=".", color='darkgrey', label='Baseline')
ax.axhline(base_mean, linestyle="--", color='silver', label='Accumulated AUC = %.3f'%base_mean)

plt.xticks(fontsize=15)
plt.yticks(fontsize=15)

ax.set_xticklabels([]) 
ax.set_ylabel("Time-Dependent AUC", fontsize=18)
# ax.set_title(f'{label} label using {data_set} set assembled by {version} annotation file', fontsize=22)
plt.legend(fontsize=12)

cum_events = []
cum_no_events = []
for i_time in va_times:
    i_events = len([i_event for i_event in test if (i_event[0] and i_event[1] <= i_time)])
    i_no_events = len([i_event for i_event in test if (i_event[1] > i_time)])
    cum_events.append(i_events)
    cum_no_events.append(i_no_events)

plt.table(cellText=[cum_events, cum_no_events],
          rowLabels=['Events (Cumulative)', 'No Events (-Cumulative)'],
          colLabels=va_times,
          cellLoc='center',
          loc='bottom')


plt.savefig(fo)
# Snakemake script: univariate Cox PH test of every feature with lifelines
from scipy import stats
from lifelines import CoxPHFitter
import pandas as pd
import numpy as np

event = snakemake.params[0]
days = snakemake.params[1]

def survival(df, t):
    try:
        return float(CoxPHFitter().fit(df[[t, days, event]], days, event_col=event).summary['p'])
    except:
        return np.nan

data = pd.read_csv(snakemake.input[0]).set_index('ID')
data = data.T.reset_index().rename(columns={'index':'SAMPLE'})
data.columns.names = [None]

if 'RFi' in event:
    data = pd.merge(pd.read_csv(snakemake.input[1])[['SAMPLE', 'RFi_days', 'RFi_event', 'RFS_all_event']], data, on='SAMPLE')
    data = data[data.RFi_days.notna()].reset_index(drop=True)
    data = data[data.apply(lambda x: False if (x['RFi_event'] == 0 and x['RFi_days'] < 600) or (x['RFi_event'] == False and x['RFS_all_event'] == True) else True, axis = 1)].reset_index(drop=True)
    data.drop(['RFS_all_event'], axis=1, inplace=True)
else:
    data = pd.merge(pd.read_csv(snakemake.input[1])[['SAMPLE', days, event]], data, on='SAMPLE')
    data = data[data[event].notna()].reset_index(drop=True)
    data = data[data[days].notna()].reset_index(drop=True)

ts = data.columns[3:]
ts_sig = {}
for t in ts:
    ts_sig[t] = survival(data, t)

ts_df = pd.DataFrame.from_dict(ts_sig, orient='index', columns=['p-val']).reset_index().rename(columns={'index':'fetures'})
ts_df.to_csv(snakemake.output[0], index=False)
# Snakemake script: collect gene and transcript FPKM values from StringTie output into matrices
import pandas as pd

samples = list(set([i.split('/')[-1].split('.')[0] for i in snakemake.input]))
table = snakemake.output

def get_info(l, *info):
    try:
        info_dic =  {i.strip().split(' ')[0]:i.strip().split(' ')[1].strip('"') for i in l.split('\t')[-1][:-3].split(';')}
    except:
        info_dic = {i.strip().split(' ')[0]:i.strip().split(' ')[1].strip('"') for i in l.split('\t')[-1][:-2].split(';')}
    return ','.join([info_dic[i] for i in info])

def extract_t(sp, wdir):
    tmp = pd.read_csv(wdir+'/'+sp+'/'+sp+'.gtf', sep='\t', comment='#', header=None)
    tmp = tmp[tmp[2] == 'transcript'][[8]]
    tmp[['gene_id', 'transcript_id', sp]] = pd.DataFrame(tmp[8].apply(lambda x: get_info(x, 'gene_id', 'transcript_id', 'FPKM')))[8].str.split(',', expand=True)
    tmp['ID'] = tmp.apply(lambda x: x['gene_id']+'_'+x['transcript_id'], axis=1)
    tmp[sp] = tmp[sp].astype(float)
    tmp.drop(['gene_id', 'transcript_id', 8], axis=1, inplace=True)
    tmp.drop_duplicates(subset='ID', keep='first', inplace=True)
    return tmp

def extract_g(sp, wdir):
    tmp = pd.read_csv(wdir+'/'+sp+'/'+sp+'.tsv', sep='\t')[['Gene ID', 'Gene Name', 'FPKM']]
    tmp.columns = ['ID', 'GENE', sp]
    tmp['ID'] = tmp.apply(lambda x: x['GENE']+'_'+x['ID'], axis=1)
    tmp.drop('GENE', axis=1, inplace=True)
    tmp.drop_duplicates(subset='ID', keep='first', inplace=True)
    return tmp


def extract_transcript(t):
    wdir = t.split('/')[0]
    data = extract_t(samples[0], wdir)
    for sp in samples[1:]:
        data = pd.merge(data, extract_t(sp, wdir), on='ID', how='left')
    data.to_csv(t, index=False)


def extract_gene(t):
    wdir = t.split('/')[0]
    data = extract_g(samples[0], wdir)
    for sp in samples[1:]:
        data = pd.merge(data, extract_g(sp, wdir), on='ID', how='left')
    data.to_csv(t, index=False)

for t in table:
    if 'transcripts' in t:
        extract_transcript(t)
    if 'genes' in t:
        extract_gene(t)
# Adjust the univariate Cox p-values for multiple testing (FDR)
ts <- read.csv(snakemake@input[[1]], header=T, sep = ',')
ts$p.adjust <-p.adjust(ts$p.val,method="fdr", n=length(ts$p.val))
write.csv(ts, file=snakemake@output[[1]], row.names = FALSE, quote=FALSE)
R From line 1 of scripts/FDR.R
# Snakemake script: aggregate per-batch selections, drop features from low-c-index batches,
# and keep the frequently selected ones
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt



comment = {}
for j in snakemake.input: 
    with open(j) as f:
        for l in f:
            if l.startswith('#batch'):
                i = l[6:].rstrip()
                comment[j+'_cindex'] = []
                comment[j+'_alpha'] = []
                comment[j+'_features'] = []
            if l.startswith('#cindex'):
                comment[j+'_cindex'].append(float(l.split(',')[0][8:]))
                comment[j+'_alpha'].append(float(l.split(',')[1][7:-2]))
            if l.startswith('#fea'):
                comment[j+'_features'].append(l[10:-1])
comment = pd.DataFrame(comment)

cindex = []
for j in snakemake.input:
    cindex += list(comment[j+'_cindex'])
cindex_ = np.quantile(cindex, 0.5)

batches = pd.read_csv(snakemake.input[0], index_col=[0], comment='#')
batches.columns = [snakemake.input[0]]
for i in snakemake.input[1:]:
    f = pd.read_csv(i, index_col=[0], comment='#')
    f.columns = [i]
    batches = pd.concat([batches, f], axis=1)

for j in snakemake.input:
    feature_drop = ','.join(comment[comment[j+'_cindex'] < cindex_][j+'_features']).split(',')
    if len(feature_drop) > 1:
        batches.loc[[x for x in feature_drop if x != ''], j] = 0

# genes: top 80%, transcripts: top 70%
batches[batches==0]=np.nan
batches['MEAN'] = batches.apply(np.mean, axis=1)
batches['Count'] = batches.apply(lambda x: x.notna().sum(), axis=1)-1
features = batches[batches['Count']>np.ceil(max(batches['Count'])*0.1)]
#features = batches[batches['Count']>0]


pd.DataFrame(features.index, columns=['Feature']).to_csv(snakemake.output[0], index=False)
batches.to_csv(snakemake.output[1], index=True)
with open(snakemake.output[2], 'w') as f:
    f.write('#mean:{a},median:{b}\n'.format(a=np.mean(cindex), b=cindex_))
comment.to_csv(snakemake.output[2], index=False, mode='a')
# Snakemake script: single-run elastic-net Cox feature selection with a cross-validated alpha
import pandas as pd
import numpy as np
from sksurv.linear_model import CoxnetSurvivalAnalysis
from sksurv.preprocessing import OneHotEncoder
from sksurv.metrics import cumulative_dynamic_auc
from sklearn.model_selection import GridSearchCV, KFold, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, MinMaxScaler


event = snakemake.params[0]
days = snakemake.params[1]


def load_data(dataP, spInfo):
    data = pd.read_csv(dataP).set_index('ID')
    sample_info = pd.read_csv(spInfo)
    data.drop(['MEAN','VAR'], axis=1, inplace=True)
    data = data.T.reset_index().rename(columns={'index':'SAMPLE'})
    data.columns.names = [None]
    if 'RFi' in event:
        data = pd.merge(sample_info[['SAMPLE', 'RFi_days', 'RFi_event', 'RFS_all_event', 'age', 'treatGroup']], data, on='SAMPLE')
        data = data[data.RFi_days.notna()].reset_index(drop=True)
        data = data[data.apply(lambda x: False if (x['RFi_event'] == 0 and x['RFi_days'] < 600) or (x['RFi_event'] == False and x['RFS_all_event'] == True) else True, axis = 1)].reset_index(drop=True)
        data.drop(['RFS_all_event'], axis=1, inplace=True)
    else:
        data = pd.merge(sample_info[['SAMPLE', days, event, 'age', 'treatGroup']], data, on='SAMPLE')
        data = data[data[event].notna()].reset_index(drop=True)
        data = data[data[days].notna()].reset_index(drop=True)
    x, y = data.iloc[::, 3:], data.iloc[::, 1:3]
    treatment = x.treatGroup.astype('category')
    x.drop('treatGroup', axis=1, inplace=True)
    x = pd.DataFrame(StandardScaler().fit_transform(x), columns=x.columns).dropna(axis=1)
    x['treatGroup'] = treatment
    x = OneHotEncoder().fit_transform(x).replace(np.nan, 0)
    return x, y

x, y = load_data(snakemake.input[0], snakemake.input[1])
os_type = {'names':('events', 'time'), 'formats':('?', '<f8')}
y = np.array([(a, b) for a, b in zip(y[event].astype(bool), y[days])], dtype=os_type)

xtrain, xval, ytrain, yval = train_test_split(x, y, test_size=0.2, random_state=954)

templete = CoxnetSurvivalAnalysis(l1_ratio=0.9, alpha_min_ratio=0.01, n_alphas=20)
templete.fit(xtrain, ytrain)
cv = KFold(n_splits=5, shuffle=True, random_state=0)
grids = GridSearchCV(
    make_pipeline(CoxnetSurvivalAnalysis(l1_ratio=0.9)),
    param_grid={"coxnetsurvivalanalysis__alphas": [[v] for v in templete.alphas_]},
    cv=cv,
    error_score=0.5,
    n_jobs=16).fit(xtrain, ytrain)

best_model = grids.best_estimator_.named_steps["coxnetsurvivalanalysis"]
best_coefs = pd.DataFrame(
    best_model.coef_,
    index=xtrain.columns,
    columns=['coefficient']
)

pd.DataFrame(grids.cv_results_).to_csv(snakemake.output[0], index=False)
best_coefs.query("coefficient != 0").to_csv(snakemake.output[1])
# Snakemake script: ensemble feature selection with elastic-net Cox base learners
# fitted on random feature groups
import pandas as pd
import numpy as np
import warnings
from sksurv.linear_model import CoxnetSurvivalAnalysis, CoxPHSurvivalAnalysis
from sksurv.preprocessing import OneHotEncoder
from sksurv.metrics import cumulative_dynamic_auc
from sklearn.model_selection import GridSearchCV, KFold, train_test_split, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from matplotlib import pyplot as plt

def load_data(dataP, spInfo, features):
    data = pd.read_csv(dataP).set_index('ID')
    sample_info = pd.read_csv(spInfo)
#    data = np.log2(data + 1)
#    data['MEAN'] = data.apply(np.mean, axis=1)
#    data['VAR'] = data.apply(np.var, axis=1)
#    data = data[data['MEAN']>0.01]
#    data = data.sort_values(by='VAR').iloc[-int(data.shape[0]*0.8):]
    data.drop(['MEAN','VAR'], axis=1, inplace=True)
    data = data.T.reset_index().rename(columns={'index':'SAMPLE'})
    data.columns.names = [None]
    if isinstance(features, pd.DataFrame):
        data = data[['SAMPLE']+list(features.Feature)]
    if 'RFi' in event:
        data = pd.merge(sample_info[['SAMPLE', 'RFi_days', 'RFi_event', 'RFS_all_event', 'age', 'treatGroup']], data, on='SAMPLE')
        data = data[data.RFi_days.notna()].reset_index(drop=True)
        data = data[data.apply(lambda x: False if (x['RFi_event'] == 0 and x['RFi_days'] < 600) or (x['RFi_event'] == False and x['RFS_all_event'] == True) else True, axis = 1)].reset_index(drop=True)
        data.drop(['RFS_all_event'], axis=1, inplace=True)
    else:
        data = pd.merge(sample_info[['SAMPLE', days, event, 'age', 'treatGroup']], data, on='SAMPLE')
        data = data[data[event].notna()].reset_index(drop=True)
        data = data[data[days].notna()].reset_index(drop=True)
    x, y = data.iloc[::, 3:], data.iloc[::, 1:3]
    treatment = x.treatGroup.astype('category')
    x.drop('treatGroup', axis=1, inplace=True)
    x = pd.DataFrame(StandardScaler().fit_transform(x), columns=x.columns).dropna(axis=1) 
    x['treatGroup'] = treatment
    x = OneHotEncoder().fit_transform(x).replace(np.nan, 0)
    return x, y

def gcv_plots(grids, figPath):
    cv_results = pd.DataFrame(grids.cv_results_)
    alphas = cv_results.param_coxnetsurvivalanalysis__alphas.map(lambda x: x[0])
    mean = cv_results.mean_test_score
    std = cv_results.std_test_score
    fig, ax = plt.subplots(figsize=(9, 6))
    ax.plot(alphas, mean)
    ax.fill_between(alphas, mean - std, mean + std, alpha=.15)
    ax.set_xscale("log")
    ax.set_ylabel("concordance index")
    ax.set_xlabel("alpha")
    ax.axvline(grids.best_params_["coxnetsurvivalanalysis__alphas"], c="C1")
    ax.axhline(0.5, color="grey", linestyle="--")
    ax.grid(True)
    fig.savefig(figPath)
    plt.close()
def cox_linear(x_train, batchid, indexes, threads):
    os_type = {'names':('events', 'time'), 'formats':('?', '<f8')}
    kept = ['age', 'treatGroup=EndoImmu', 'treatGroup=Endo', 'treatGroup=EndoCyto', 'treatGroup=EndoCytoImmu', 'treatGroup=CytoImmu']
    y_train2 = np.array([(m, n) for m, n in zip(ytrain[event].astype(bool), ytrain[days])], dtype=os_type)
    base = CoxPHSurvivalAnalysis()
    base_score = cross_val_score(base, xtrain[kept], y_train2, cv=indexes, n_jobs=threads)
    gcv = GridSearchCV(
        make_pipeline(CoxPHSurvivalAnalysis()),
        param_grid={"coxphsurvivalanalysis__alpha": [0]},
        cv=indexes,
        error_score=np.mean(base_score),
        n_jobs=threads).fit(xtrain[x_train.columns], y_train2)
    best_model = gcv.best_estimator_.named_steps["coxphsurvivalanalysis"]
    selection_coefs = pd.DataFrame(
        best_model.coef_,
        index=x_train.columns,
        columns=[batchid]
    )
    return selection_coefs, np.mean(base_score), gcv 

def coefs_batch(x_train, batchid, indexes, threads=8):

    os_type = {'names':('events', 'time'), 'formats':('?', '<f8')}
    #x_train = x_train.iloc[np.append(indexes[0][0], indexes[0][1]),:]
    #y_train = ytrain.iloc[np.append(indexes[0][0], indexes[0][1]),:]
    #y_train = np.array([(m, n) for m, n in zip(y_train[event].astype(bool), y_train[days])], dtype=os_type)
    y_train2 = np.array([(m, n) for m, n in zip(ytrain[event].astype(bool), ytrain[days])], dtype=os_type)

    # test_run = CoxnetSurvivalAnalysis(l1_ratio=0.9, alpha_min_ratio=0.05, n_alphas=30)
    # test_run.fit(x_train, y_train)
    # alphas = 10. ** np.linspace(-2.3, -1.8, 8)
    gcv = GridSearchCV(
        make_pipeline(CoxnetSurvivalAnalysis(l1_ratio=0.9, tol=1e-10, max_iter=10000)),
        # param_grid={"coxnetsurvivalanalysis__alphas": [[v] for v in alphas]},
        param_grid={"coxnetsurvivalanalysis__alphas": [[v] for v in [0.02, 0.03, 0.04, 0.05, 0.06, 0.07]]},
        cv=indexes,
        error_score=0.6,
        n_jobs=threads).fit(xtrain[x_train.columns], y_train2)
    best_model = gcv.best_estimator_.named_steps["coxnetsurvivalanalysis"]
    best_coefs = pd.DataFrame(
        best_model.coef_,
        index=x_train.columns,
        columns=[batchid]
    )
    return best_coefs.iloc[6:,:], gcv


def training_batch(x_train, y_train, iteration):
    if 'RFi' in event:
        kept = ['age', 'treatGroup=Endo', 'treatGroup=EndoCyto', 'treatGroup=EndoCytoImmu']
    else:
        kept = ['age', 'treatGroup=EndoImmu', 'treatGroup=Endo', 'treatGroup=EndoCyto', 'treatGroup=EndoCytoImmu', 'treatGroup=CytoImmu'] 
    xa = x_train[kept]
    xb = x_train.drop(kept, axis=1)
    yTrue = y_train[y_train[event] == True].reset_index()
    lq = int(yTrue.describe().iloc[4,0]) # lower quartile
    yFalse = y_train[y_train.apply(lambda x: True if x[event] == False and x[days] > lq else False, axis=1)].reset_index()
    if yTrue.shape[0] > yFalse.shape[0]:
        majority = True
    else:
        majority = False
    batches = pd.DataFrame(index=xb.columns)

    for i in range(iteration):
        truePart = []
        falsePart = []
        indexes = []
        yTrueKF = KFold(n_splits=5, shuffle=True, random_state=1723).split(yTrue)
        yFalseKF = KFold(n_splits=5, shuffle=True, random_state=1938).split(yFalse)
        for a, j in yTrueKF:
            np.random.shuffle(j) # random part
            if majority:
                truePart.append(yTrue.iloc[j]['index'][:int(yFalse.shape[0]/5)])
            else:
                truePart.append(yTrue.iloc[j]['index'])

        for a, j in yFalseKF:
            np.random.shuffle(j) # random part
            if not majority:
                falsePart.append(yFalse.iloc[j]['index'][:int(yTrue.shape[0]/5)])
            else:
                falsePart.append(yFalse.iloc[j]['index'])
        for a in range(5):
            to_del = [0,1,2,3,4]
            to_del.pop(a)
            trainPart = np.array([], dtype=int)
            valPart = np.append(truePart[a], falsePart[a])
            for j in to_del:
                tep = np.append(truePart[j], falsePart[j])
                trainPart = np.append(trainPart, tep)
            indexes.append([trainPart, valPart])
        if x_train.shape[1]//int(yTrue.shape[0]*1.6) == 0:
            kf = [[np.nan, 0]]
        else:
            kf = KFold(n_splits=x_train.shape[1]//int(yTrue.shape[0]*1.6)+1, shuffle=True).split(xb.T) # splits for features
        batch = pd.DataFrame(columns=['batch'+str(i+1)])
        with open(out, 'a') as f:
            f.write('#batch'+str(i+1)+'\n')
            for train_index, test_index in kf:
                if isinstance(test_index, int):
                    X_sub = xb.T
                else:
                    X_sub = xb.T.iloc[test_index,:]
                xc = pd.concat([xa, X_sub.T], axis=1)
                # xc = xc.iloc[np.append(indexes[0][0], indexes[0][1]),:]
                b, gcv = coefs_batch(xc, 'batch'+str(i+1), indexes, threads=16)
                f.write('#cindex:'+str(gcv.best_score_)+',alpha:'+str(gcv.best_params_["coxnetsurvivalanalysis__alphas"])+'\n')
                # f.write('#baseindex={m}, cindex={n}\n'.format(m=baseScore, n=gcv.best_score_))
                f.write('#features:'+','.join(b[b['batch'+str(i+1)]!=0].index)+'\n')
                # f.write('# '+str(indexes)+'\n')
                batch = pd.concat([batch, b]) # vertical
                # gcv_plots(gcv, '/home/chixu/isoform/denovo/transcripts/figs/'+out.split('/')[-1].split('.')[0]+'_'+str(len(batch.index))+'.jpg')
                #pd.DataFrame(gcv.cv_results_).to_csv('/home/chixu/isoform/normal/genes/figs/'+out.split('/')[-1].split('.')[0]+'_'+str(len(batch.index))+'.csv')
        batches = pd.concat([batches, batch], axis=1) # horizontal
    return batches



event = snakemake.params[0]
days = snakemake.params[1]


if snakemake.input[1] == snakemake.input[2]:
    features = ''
else:
    features = pd.read_csv(snakemake.input[2])    
x, y = load_data(snakemake.input[0], snakemake.input[1], features)
xtrain, xval, ytrain, yval = train_test_split(x, y, test_size=0.2, random_state=954)
out = snakemake.output[0]
xtrain = xtrain.reset_index(drop=True)
ytrain = ytrain.reset_index(drop=True)
batch = training_batch(xtrain, ytrain, iteration=1)
# batch = training_batch(xtrain, ytrain, [21])
batch.to_csv(out, mode='a')
shell:
    "stringtie {input.bam} -l {wildcards.sp} -p 2 -G {input.gtf} -o {output.gtf}"
shell:
    "ls pre_assembly/*.gtf | head -10 > pre_assembly/mergelist.txt; "
    "stringtie --merge -G {input.gtf} -c 1 -o {output.gtf} pre_assembly/mergelist.txt"
shell:
    "stringtie {input.bam} -l {wildcards.sp} -p 2 -e -G {input.gtf} -o {output.gtf} -A {output.csv} -C {output.cov}"
script:
    'scripts/extract.py'
script:
    "scripts/pre_selection.py"
SnakeMake From line 119 of master/Snakefile
script:
    "scripts/filter.py"
SnakeMake From line 152 of master/Snakefile
script:
     'scripts/conclude1.py'
SnakeMake From line 186 of master/Snakefile
script:
    'scripts/conclude2.py'
SnakeMake From line 197 of master/Snakefile
script:
    "scripts/one_run.py"
SnakeMake From line 215 of master/Snakefile
script:
    "scripts/cox_unitest.py"
SnakeMake From line 233 of master/Snakefile
script:
    "scripts/FDR.R"
SnakeMake From line 247 of master/Snakefile
    shell:
        ''

rule cox_analysis:
    input:
        data = lambda wc: '{which}/genes.csv' if 'gs' in wc.path else '{which}/transcripts.csv',
        info = config["sample_info"],
    output:
        '{which}/{path}/cox_'+EVENT.split("_")[0]+'_{method}_analysis.txt'
SnakeMake From line 258 of master/Snakefile
script:
    'scripts/cox_analysis.py'
SnakeMake From line 273 of master/Snakefile
script:
    'scripts/cox_plot.py'
SnakeMake From line 288 of master/Snakefile
script:
    'scripts/cox_plot2.py'
SnakeMake From line 304 of master/Snakefile
script:
    'scripts/cox_plot3.py'
SnakeMake From line 321 of master/Snakefile
Repository: https://github.com/ChiXX/isoform