survivalist.survstack.survstacker
import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import RandomForestClassifier
from .transformer import SurvivalStacker
from ..util import check_array_survival
from ..base import SurvivalAnalysisMixin
from ..linear_model.coxph import BreslowEstimator
from ..ensemble.survival_loss import (
    LOSS_FUNCTIONS,
    CoxPH,
)
from ..functions import StepFunction


class SurvStacker(SurvivalAnalysisMixin):
    """
    A class to create a Survival Stacker for any classifier.

    The training data is transformed by ``SurvivalStacker`` into a
    classification problem, a ``CalibratedClassifierCV`` wrapped around
    ``clf`` is fitted on it, and the classifier's positive-class
    probabilities are handed back to the stacker to obtain survival curves.
    """

    def __init__(self, clf=None, loss="squared", random_state=42, **kwargs):
        """
        Parameters
        ----------
        clf : classifier, default: None
            The classifier to be used for stacking. When None, a fresh
            ``RandomForestClassifier()`` is created for this instance.

        loss : {'coxph', 'squared', 'ipcwls'}, optional, default: 'squared'
            Loss function to be optimized.

        random_state : int seed, RandomState instance, or None, default: 42
            The seed of the pseudo random number generator. It is set on
            ``clf`` when the classifier accepts a ``random_state`` parameter.

        kwargs : additional parameters to be passed to CalibratedClassifierCV
            ``cv`` defaults to 3 when not given.

        Raises
        ------
        ValueError
            If ``loss`` is not one of the supported values.
        """
        # Validate before the registry lookup so that an unknown loss raises
        # the documented ValueError rather than a bare KeyError.
        if loss not in ("coxph", "squared", "ipcwls"):
            raise ValueError(
                f"Invalid loss value: {loss}. Choose from 'coxph', 'squared', or 'ipcwls'."
            )

        # A default built in the signature would be one object shared (and
        # mutated by set_params below) across every SurvStacker instance.
        if clf is None:
            clf = RandomForestClassifier()

        self.random_state = random_state
        try:
            clf.set_params(random_state=random_state)
        except (AttributeError, TypeError, ValueError):
            # Best effort: the classifier has no set_params or does not
            # accept a random_state parameter.
            pass

        # Keep the historical cv=3 default while letting callers override it.
        kwargs.setdefault("cv", 3)
        self.clf = CalibratedClassifierCV(clf, **kwargs)
        self.ss = SurvivalStacker()
        self._baseline_model = None
        self.loss = loss
        self._loss = LOSS_FUNCTIONS[loss]()

        self.times_ = None
        self.unique_times_ = None

    def _get_baseline_model(self):
        """Get the baseline model for the survival stacker."""
        return self._baseline_model

    def _set_baseline_model(self, X, event, time):
        """Fit a Breslow baseline for the 'coxph' loss, otherwise clear it."""
        if isinstance(self._loss, CoxPH):
            # NOTE(review): predict() returns threshold-crossing times, which
            # are passed on here as risk scores - confirm this is intended.
            risk_scores = self.predict(X)
            self._baseline_model = BreslowEstimator().fit(risk_scores, event, time)
        else:
            self._baseline_model = None

    def fit(self, X, y, **kwargs):
        """
        Fit the Survival Stacker to the data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The input samples. Objects exposing ``to_numpy`` are converted
            to arrays first.

        y : array-like, shape (n_samples,)
            The survival targets, from which ``check_array_survival``
            extracts an event indicator and an observed time per sample.

        kwargs : additional parameters to be passed to the fitting function

        Returns
        -------
        self : object
            Returns self.
        """
        if hasattr(X, "to_numpy"):
            X = X.to_numpy()

        # Validate the targets up front so malformed input is rejected
        # before any fitted state is written.
        event, time = check_array_survival(X, y)

        # Get survival stacker predictions
        X_oo, y_oo = self.ss.fit_transform(X, y)
        self.times_ = self.ss.times
        # np.unique already returns its result sorted.
        self.unique_times_ = np.unique(self.ss.times)

        # Fit classifier
        self.clf.fit(X_oo, y_oo, **kwargs)

        # Set baseline model
        self._set_baseline_model(X, event, time)

        return self

    def _predict_survival_function(self, X):
        """
        Predict survival function.

        Transforms ``X`` with the fitted stacker, takes the classifier's
        positive-class probabilities and lets the stacker convert them.
        """
        X_risk, _ = self.ss.transform(X)
        oo_test_estimates = self.clf.predict_proba(X_risk)[:, 1]
        return self.ss.predict_survival_function(oo_test_estimates)

    def predict(self, X, threshold=0.5):
        """
        Predict survival times using a threshold.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The input samples.
        threshold : float, default=0.5
            Survival level whose first crossing defines the predicted time.

        Returns
        -------
        ndarray, shape (n_samples,)
            For each sample, the entry of ``unique_times_`` at the first
            column where the survival value is <= ``threshold``, or the last
            entry of ``unique_times_`` when it never gets that low.
        """
        # Same input handling as predict_survival_function.
        if hasattr(X, "to_numpy"):
            X = X.to_numpy()

        surv = np.asarray(self._predict_survival_function(X))
        if surv.ndim == 1:
            surv = surv.reshape(1, -1)

        crossings = surv <= threshold
        # argmax yields the first True per row (0 when the row has none).
        cross_indices = np.argmax(crossings, axis=1)
        valid_crossings = crossings.any(axis=1)

        return np.where(
            valid_crossings,
            self.unique_times_[cross_indices],
            self.unique_times_[-1],
        )

    def predict_cumulative_hazard_function(self, X, return_array=False):
        """
        Predict cumulative hazard function.

        Delegates to ``_predict_cumulative_hazard_function`` with the stored
        baseline model and the output of ``predict(X)``.
        """
        # NOTE(review): the baseline model is None unless loss='coxph';
        # confirm how the inherited helper handles that case.
        return self._predict_cumulative_hazard_function(
            self._get_baseline_model(), self.predict(X), return_array
        )

    def predict_survival_function(self, X, return_array=False):
        """
        Predict survival function.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The input samples.
        return_array : bool, default=False
            Whether to return the survival function as an array.

        Returns
        -------
        array-like or ndarray of StepFunction
            Predicted survival function for each sample.
        """
        if hasattr(X, "to_numpy"):
            X = X.to_numpy()

        surv = self._predict_survival_function(X)

        if return_array:
            return surv

        surv = np.asarray(surv)
        if surv.ndim == 1:
            surv = surv.reshape(1, -1)

        n_times = len(self.unique_times_)
        if surv.shape[1] != n_times:
            # Resample every curve onto as many points as unique_times_;
            # the grids are identical for all rows, so build them once.
            x_old = np.linspace(0, 1, surv.shape[1])
            x_new = np.linspace(0, 1, n_times)
            rows = [np.interp(x_new, x_old, row) for row in surv]
        else:
            rows = surv

        # Fill an object array explicitly so the result is always a 1-D
        # array of StepFunction objects, also when there are no samples.
        funcs = np.empty(len(rows), dtype=object)
        for i, row in enumerate(rows):
            funcs[i] = StepFunction(x=self.unique_times_, y=row)
        return funcs
class SurvStacker(SurvivalAnalysisMixin):
    """
    A class to create a Survival Stacker for any classifier.

    The training data is transformed by ``SurvivalStacker`` into a
    classification problem, a ``CalibratedClassifierCV`` wrapped around
    ``clf`` is fitted on it, and the classifier's positive-class
    probabilities are handed back to the stacker to obtain survival curves.
    """

    def __init__(self, clf=None, loss="squared", random_state=42, **kwargs):
        """
        Parameters
        ----------
        clf : classifier, default: None
            The classifier to be used for stacking. When None, a fresh
            ``RandomForestClassifier()`` is created for this instance.

        loss : {'coxph', 'squared', 'ipcwls'}, optional, default: 'squared'
            Loss function to be optimized.

        random_state : int seed, RandomState instance, or None, default: 42
            The seed of the pseudo random number generator. It is set on
            ``clf`` when the classifier accepts a ``random_state`` parameter.

        kwargs : additional parameters to be passed to CalibratedClassifierCV
            ``cv`` defaults to 3 when not given.

        Raises
        ------
        ValueError
            If ``loss`` is not one of the supported values.
        """
        # Validate before the registry lookup so that an unknown loss raises
        # the documented ValueError rather than a bare KeyError.
        if loss not in ("coxph", "squared", "ipcwls"):
            raise ValueError(
                f"Invalid loss value: {loss}. Choose from 'coxph', 'squared', or 'ipcwls'."
            )

        # A default built in the signature would be one object shared (and
        # mutated by set_params below) across every SurvStacker instance.
        if clf is None:
            clf = RandomForestClassifier()

        self.random_state = random_state
        try:
            clf.set_params(random_state=random_state)
        except (AttributeError, TypeError, ValueError):
            # Best effort: the classifier has no set_params or does not
            # accept a random_state parameter.
            pass

        # Keep the historical cv=3 default while letting callers override it.
        kwargs.setdefault("cv", 3)
        self.clf = CalibratedClassifierCV(clf, **kwargs)
        self.ss = SurvivalStacker()
        self._baseline_model = None
        self.loss = loss
        self._loss = LOSS_FUNCTIONS[loss]()

        self.times_ = None
        self.unique_times_ = None

    def _get_baseline_model(self):
        """Get the baseline model for the survival stacker."""
        return self._baseline_model

    def _set_baseline_model(self, X, event, time):
        """Fit a Breslow baseline for the 'coxph' loss, otherwise clear it."""
        if isinstance(self._loss, CoxPH):
            # NOTE(review): predict() returns threshold-crossing times, which
            # are passed on here as risk scores - confirm this is intended.
            risk_scores = self.predict(X)
            self._baseline_model = BreslowEstimator().fit(risk_scores, event, time)
        else:
            self._baseline_model = None

    def fit(self, X, y, **kwargs):
        """
        Fit the Survival Stacker to the data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The input samples. Objects exposing ``to_numpy`` are converted
            to arrays first.

        y : array-like, shape (n_samples,)
            The survival targets, from which ``check_array_survival``
            extracts an event indicator and an observed time per sample.

        kwargs : additional parameters to be passed to the fitting function

        Returns
        -------
        self : object
            Returns self.
        """
        if hasattr(X, "to_numpy"):
            X = X.to_numpy()

        # Validate the targets up front so malformed input is rejected
        # before any fitted state is written.
        event, time = check_array_survival(X, y)

        # Get survival stacker predictions
        X_oo, y_oo = self.ss.fit_transform(X, y)
        self.times_ = self.ss.times
        # np.unique already returns its result sorted.
        self.unique_times_ = np.unique(self.ss.times)

        # Fit classifier
        self.clf.fit(X_oo, y_oo, **kwargs)

        # Set baseline model
        self._set_baseline_model(X, event, time)

        return self

    def _predict_survival_function(self, X):
        """
        Predict survival function.

        Transforms ``X`` with the fitted stacker, takes the classifier's
        positive-class probabilities and lets the stacker convert them.
        """
        X_risk, _ = self.ss.transform(X)
        oo_test_estimates = self.clf.predict_proba(X_risk)[:, 1]
        return self.ss.predict_survival_function(oo_test_estimates)

    def predict(self, X, threshold=0.5):
        """
        Predict survival times using a threshold.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The input samples.
        threshold : float, default=0.5
            Survival level whose first crossing defines the predicted time.

        Returns
        -------
        ndarray, shape (n_samples,)
            For each sample, the entry of ``unique_times_`` at the first
            column where the survival value is <= ``threshold``, or the last
            entry of ``unique_times_`` when it never gets that low.
        """
        # Same input handling as predict_survival_function.
        if hasattr(X, "to_numpy"):
            X = X.to_numpy()

        surv = np.asarray(self._predict_survival_function(X))
        if surv.ndim == 1:
            surv = surv.reshape(1, -1)

        crossings = surv <= threshold
        # argmax yields the first True per row (0 when the row has none).
        cross_indices = np.argmax(crossings, axis=1)
        valid_crossings = crossings.any(axis=1)

        return np.where(
            valid_crossings,
            self.unique_times_[cross_indices],
            self.unique_times_[-1],
        )

    def predict_cumulative_hazard_function(self, X, return_array=False):
        """
        Predict cumulative hazard function.

        Delegates to ``_predict_cumulative_hazard_function`` with the stored
        baseline model and the output of ``predict(X)``.
        """
        # NOTE(review): the baseline model is None unless loss='coxph';
        # confirm how the inherited helper handles that case.
        return self._predict_cumulative_hazard_function(
            self._get_baseline_model(), self.predict(X), return_array
        )

    def predict_survival_function(self, X, return_array=False):
        """
        Predict survival function.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            The input samples.
        return_array : bool, default=False
            Whether to return the survival function as an array.

        Returns
        -------
        array-like or ndarray of StepFunction
            Predicted survival function for each sample.
        """
        if hasattr(X, "to_numpy"):
            X = X.to_numpy()

        surv = self._predict_survival_function(X)

        if return_array:
            return surv

        surv = np.asarray(surv)
        if surv.ndim == 1:
            surv = surv.reshape(1, -1)

        n_times = len(self.unique_times_)
        if surv.shape[1] != n_times:
            # Resample every curve onto as many points as unique_times_;
            # the grids are identical for all rows, so build them once.
            x_old = np.linspace(0, 1, surv.shape[1])
            x_new = np.linspace(0, 1, n_times)
            rows = [np.interp(x_new, x_old, row) for row in surv]
        else:
            rows = surv

        # Fill an object array explicitly so the result is always a 1-D
        # array of StepFunction objects, also when there are no samples.
        funcs = np.empty(len(rows), dtype=object)
        for i, row in enumerate(rows):
            funcs[i] = StepFunction(x=self.unique_times_, y=row)
        return funcs
A class to create a Survival Stacker for any classifier.
def
fit(self, X, y, **kwargs):
def fit(self, X, y, **kwargs):
    """
    Fit the Survival Stacker to the data.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        The input samples. Objects exposing ``to_numpy`` are converted
        to arrays first.

    y : array-like, shape (n_samples,)
        The survival targets, from which ``check_array_survival``
        extracts an event indicator and an observed time per sample.

    kwargs : additional parameters to be passed to the fitting function

    Returns
    -------
    self : object
        Returns self.
    """
    if hasattr(X, "to_numpy"):
        X = X.to_numpy()

    # Validate the targets up front so malformed input is rejected
    # before any fitted state is written.
    event, time = check_array_survival(X, y)

    # Get survival stacker predictions
    X_oo, y_oo = self.ss.fit_transform(X, y)
    self.times_ = self.ss.times
    # np.unique already returns its result sorted.
    self.unique_times_ = np.unique(self.ss.times)

    # Fit classifier
    self.clf.fit(X_oo, y_oo, **kwargs)

    # Set baseline model
    self._set_baseline_model(X, event, time)

    return self
Fit the Survival Stacker to the data.
Parameters
X : array-like, shape (n_samples, n_features) The input samples.
y : array-like, shape (n_samples,) The survival targets; an event indicator and an observed time are extracted from it for each sample.
kwargs : additional parameters to be passed to the fitting function
Returns
self : object Returns self.
def
predict(self, X, threshold=0.5):
110 def predict(self, X, threshold=0.5): 111 """ 112 Predict survival times using a threshold. 113 """ 114 surv = self._predict_survival_function(X) 115 116 crossings = surv <= threshold 117 cross_indices = np.argmax(crossings, axis=1) 118 valid_crossings = crossings[np.arange(len(crossings)), cross_indices] 119 120 predicted_times = np.where( 121 valid_crossings, 122 self.unique_times_[cross_indices], 123 self.unique_times_[-1], 124 ) 125 return predicted_times
Predict survival times: for each sample, the first time at which the estimated survival falls to or below the threshold.