survivalist.survstack.survstacker

  1import numpy as np
  2from sklearn.calibration import CalibratedClassifierCV
  3from sklearn.ensemble import RandomForestClassifier
  4from .transformer import SurvivalStacker
  5from ..util import check_array_survival
  6from ..base import SurvivalAnalysisMixin
  7from ..linear_model.coxph import BreslowEstimator
  8from ..ensemble.survival_loss import (
  9    LOSS_FUNCTIONS,
 10    CoxPH,
 11)
 12from ..functions import StepFunction
 13
 14
 15class SurvStacker(SurvivalAnalysisMixin):
 16    """
 17    A class to create a Survival Stacker for any classifier.
 18    """
 19
 20    def __init__(self, clf=RandomForestClassifier(), loss="squared", random_state=42, **kwargs):
 21        """
 22        Parameters
 23        ----------
 24        clf : classifier, default: RandomForestClassifier()
 25            The classifier to be used for stacking.
 26
 27        loss : {'coxph', 'squared', 'ipcwls'}, optional, default: 'squared'
 28            Loss function to be optimized.
 29
 30        random_state : int seed, RandomState instance, or None, default: 42
 31            The seed of the pseudo random number generator.
 32
 33        kwargs : additional parameters to be passed to CalibratedClassifierCV
 34        """
 35        self.random_state = random_state
 36        self.clf = clf
 37        try:
 38            self.clf.set_params(random_state=self.random_state)
 39        except Exception:
 40            pass
 41        self.clf = CalibratedClassifierCV(clf, cv=3, **kwargs)
 42        self.ss = SurvivalStacker()
 43        self._baseline_model = None
 44        self.loss = loss
 45        self._loss = LOSS_FUNCTIONS[self.loss]()
 46
 47        if self.loss not in ["coxph", "squared", "ipcwls"]:
 48            raise ValueError(
 49                f"Invalid loss value: {self.loss}. Choose from 'coxph', 'squared', or 'ipcwls'.")
 50
 51        self.times_ = None
 52        self.unique_times_ = None
 53
 54    def _get_baseline_model(self):
 55        """Get the baseline model for the survival stacker."""
 56        return self._baseline_model
 57
 58    def _set_baseline_model(self, X, event, time):
 59        if isinstance(self._loss, CoxPH):
 60            risk_scores = self.predict(X)
 61            self._baseline_model = BreslowEstimator().fit(risk_scores, event, time)
 62        else:
 63            self._baseline_model = None
 64
 65    def fit(self, X, y, **kwargs):
 66        """
 67        Fit the Survival Stacker to the data.
 68
 69        Parameters
 70        ----------
 71        X : array-like, shape (n_samples, n_features)
 72            The input samples.
 73
 74        y : array-like, shape (n_samples,)
 75            The target values (survival times).
 76
 77        kwargs : additional parameters to be passed to the fitting function
 78
 79        Returns
 80        -------
 81        self : object
 82            Returns self.
 83        """
 84        if hasattr(X, "to_numpy"):
 85            X = X.to_numpy()
 86
 87        # Get survival stacker predictions
 88        X_oo, y_oo = self.ss.fit_transform(X, y)
 89        self.times_ = self.ss.times
 90        self.unique_times_ = np.sort(np.unique(self.ss.times))
 91
 92        # Fit classifier
 93        self.clf.fit(X_oo, y_oo, **kwargs)
 94
 95        # Set baseline model
 96        event, time = check_array_survival(X, y)
 97        self._set_baseline_model(X, event, time)
 98
 99        return self
100
101    def _predict_survival_function(self, X):
102        """
103        Predict survival function.
104        """
105        X_risk, _ = self.ss.transform(X)
106        oo_test_estimates = self.clf.predict_proba(X_risk)[:, 1]
107        return self.ss.predict_survival_function(oo_test_estimates)
108
109    def predict(self, X, threshold=0.5):
110        """
111        Predict survival times using a threshold.
112        """
113        surv = self._predict_survival_function(X)
114
115        crossings = surv <= threshold
116        cross_indices = np.argmax(crossings, axis=1)
117        valid_crossings = crossings[np.arange(len(crossings)), cross_indices]
118
119        predicted_times = np.where(
120            valid_crossings,
121            self.unique_times_[cross_indices],
122            self.unique_times_[-1],
123        )
124        return predicted_times
125
126    def predict_cumulative_hazard_function(self, X, return_array=False):
127        """
128        Predict cumulative hazard function.
129        """
130        return self._predict_cumulative_hazard_function(self._get_baseline_model(), self.predict(X), return_array)
131
132    def predict_survival_function(self, X, return_array=False):
133        """
134        Predict survival function.
135
136        Parameters
137        ----------
138        X : array-like, shape (n_samples, n_features)
139            The input samples.
140        return_array : bool, default=False
141            Whether to return the survival function as an array.
142
143        Returns
144        -------
145        array-like or list of StepFunction
146            Predicted survival function for each sample.
147        """
148        if hasattr(X, "to_numpy"):
149            X = X.to_numpy()
150
151        surv = self._predict_survival_function(X)
152
153        if return_array:
154            return surv
155
156        funcs = []
157        surv = np.asarray(surv)
158        if surv.ndim == 1:
159            surv = surv.reshape(1, -1)
160
161        for i in range(surv.shape[0]):
162            if len(self.unique_times_) != len(surv[i]):
163                x_old = np.linspace(0, 1, len(surv[i]))
164                x_new = np.linspace(0, 1, len(self.unique_times_))
165                surv_interp = np.interp(x_new, x_old, surv[i])
166            else:
167                surv_interp = surv[i]
168            func = StepFunction(x=self.unique_times_, y=surv_interp)
169            funcs.append(func)
170
171        return np.array(funcs)
class SurvStacker(survivalist.base.SurvivalAnalysisMixin):
 16class SurvStacker(SurvivalAnalysisMixin):
 17    """
 18    A class to create a Survival Stacker for any classifier.
 19    """
 20
 21    def __init__(self, clf=RandomForestClassifier(), loss="squared", random_state=42, **kwargs):
 22        """
 23        Parameters
 24        ----------
 25        clf : classifier, default: RandomForestClassifier()
 26            The classifier to be used for stacking.
 27
 28        loss : {'coxph', 'squared', 'ipcwls'}, optional, default: 'squared'
 29            Loss function to be optimized.
 30
 31        random_state : int seed, RandomState instance, or None, default: 42
 32            The seed of the pseudo random number generator.
 33
 34        kwargs : additional parameters to be passed to CalibratedClassifierCV
 35        """
 36        self.random_state = random_state
 37        self.clf = clf
 38        try:
 39            self.clf.set_params(random_state=self.random_state)
 40        except Exception:
 41            pass
 42        self.clf = CalibratedClassifierCV(clf, cv=3, **kwargs)
 43        self.ss = SurvivalStacker()
 44        self._baseline_model = None
 45        self.loss = loss
 46        self._loss = LOSS_FUNCTIONS[self.loss]()
 47
 48        if self.loss not in ["coxph", "squared", "ipcwls"]:
 49            raise ValueError(
 50                f"Invalid loss value: {self.loss}. Choose from 'coxph', 'squared', or 'ipcwls'.")
 51
 52        self.times_ = None
 53        self.unique_times_ = None
 54
 55    def _get_baseline_model(self):
 56        """Get the baseline model for the survival stacker."""
 57        return self._baseline_model
 58
 59    def _set_baseline_model(self, X, event, time):
 60        if isinstance(self._loss, CoxPH):
 61            risk_scores = self.predict(X)
 62            self._baseline_model = BreslowEstimator().fit(risk_scores, event, time)
 63        else:
 64            self._baseline_model = None
 65
 66    def fit(self, X, y, **kwargs):
 67        """
 68        Fit the Survival Stacker to the data.
 69
 70        Parameters
 71        ----------
 72        X : array-like, shape (n_samples, n_features)
 73            The input samples.
 74
 75        y : array-like, shape (n_samples,)
 76            The target values (survival times).
 77
 78        kwargs : additional parameters to be passed to the fitting function
 79
 80        Returns
 81        -------
 82        self : object
 83            Returns self.
 84        """
 85        if hasattr(X, "to_numpy"):
 86            X = X.to_numpy()
 87
 88        # Get survival stacker predictions
 89        X_oo, y_oo = self.ss.fit_transform(X, y)
 90        self.times_ = self.ss.times
 91        self.unique_times_ = np.sort(np.unique(self.ss.times))
 92
 93        # Fit classifier
 94        self.clf.fit(X_oo, y_oo, **kwargs)
 95
 96        # Set baseline model
 97        event, time = check_array_survival(X, y)
 98        self._set_baseline_model(X, event, time)
 99
100        return self
101
102    def _predict_survival_function(self, X):
103        """
104        Predict survival function.
105        """
106        X_risk, _ = self.ss.transform(X)
107        oo_test_estimates = self.clf.predict_proba(X_risk)[:, 1]
108        return self.ss.predict_survival_function(oo_test_estimates)
109
110    def predict(self, X, threshold=0.5):
111        """
112        Predict survival times using a threshold.
113        """
114        surv = self._predict_survival_function(X)
115
116        crossings = surv <= threshold
117        cross_indices = np.argmax(crossings, axis=1)
118        valid_crossings = crossings[np.arange(len(crossings)), cross_indices]
119
120        predicted_times = np.where(
121            valid_crossings,
122            self.unique_times_[cross_indices],
123            self.unique_times_[-1],
124        )
125        return predicted_times
126
127    def predict_cumulative_hazard_function(self, X, return_array=False):
128        """
129        Predict cumulative hazard function.
130        """
131        return self._predict_cumulative_hazard_function(self._get_baseline_model(), self.predict(X), return_array)
132
133    def predict_survival_function(self, X, return_array=False):
134        """
135        Predict survival function.
136
137        Parameters
138        ----------
139        X : array-like, shape (n_samples, n_features)
140            The input samples.
141        return_array : bool, default=False
142            Whether to return the survival function as an array.
143
144        Returns
145        -------
146        array-like or list of StepFunction
147            Predicted survival function for each sample.
148        """
149        if hasattr(X, "to_numpy"):
150            X = X.to_numpy()
151
152        surv = self._predict_survival_function(X)
153
154        if return_array:
155            return surv
156
157        funcs = []
158        surv = np.asarray(surv)
159        if surv.ndim == 1:
160            surv = surv.reshape(1, -1)
161
162        for i in range(surv.shape[0]):
163            if len(self.unique_times_) != len(surv[i]):
164                x_old = np.linspace(0, 1, len(surv[i]))
165                x_new = np.linspace(0, 1, len(self.unique_times_))
166                surv_interp = np.interp(x_new, x_old, surv[i])
167            else:
168                surv_interp = surv[i]
169            func = StepFunction(x=self.unique_times_, y=surv_interp)
170            funcs.append(func)
171
172        return np.array(funcs)

A class to create a Survival Stacker for any classifier.

def fit(self, X, y, **kwargs):
 66    def fit(self, X, y, **kwargs):
 67        """
 68        Fit the Survival Stacker to the data.
 69
 70        Parameters
 71        ----------
 72        X : array-like, shape (n_samples, n_features)
 73            The input samples.
 74
 75        y : array-like, shape (n_samples,)
 76            The target values (survival times).
 77
 78        kwargs : additional parameters to be passed to the fitting function
 79
 80        Returns
 81        -------
 82        self : object
 83            Returns self.
 84        """
 85        if hasattr(X, "to_numpy"):
 86            X = X.to_numpy()
 87
 88        # Get survival stacker predictions
 89        X_oo, y_oo = self.ss.fit_transform(X, y)
 90        self.times_ = self.ss.times
 91        self.unique_times_ = np.sort(np.unique(self.ss.times))
 92
 93        # Fit classifier
 94        self.clf.fit(X_oo, y_oo, **kwargs)
 95
 96        # Set baseline model
 97        event, time = check_array_survival(X, y)
 98        self._set_baseline_model(X, event, time)
 99
100        return self

Fit the Survival Stacker to the data.

Parameters

X : array-like, shape (n_samples, n_features) The input samples.

y : array-like, shape (n_samples,) The survival targets, from which the event indicator and the observed time of each sample are extracted.

kwargs : additional parameters to be passed to the fitting function

Returns

self : object Returns self.

def predict(self, X, threshold=0.5):
110    def predict(self, X, threshold=0.5):
111        """
112        Predict survival times using a threshold.
113        """
114        surv = self._predict_survival_function(X)
115
116        crossings = surv <= threshold
117        cross_indices = np.argmax(crossings, axis=1)
118        valid_crossings = crossings[np.arange(len(crossings)), cross_indices]
119
120        predicted_times = np.where(
121            valid_crossings,
122            self.unique_times_[cross_indices],
123            self.unique_times_[-1],
124        )
125        return predicted_times

Predict survival times using a threshold.