survivalist.testing
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""Assertion helpers and estimator discovery utilities for tests."""
from importlib import import_module
import inspect
from pathlib import Path
import pkgutil

import numpy as np
from numpy.testing import assert_almost_equal, assert_array_equal
import pytest
from sklearn.base import BaseEstimator, TransformerMixin

import survivalist
from survivalist.metrics import concordance_index_censored


def assert_cindex_almost_equal(event_indicator, event_time, estimate, expected):
    """Compare the output of ``concordance_index_censored`` with ``expected``.

    The first element of the result is compared approximately; all remaining
    elements are compared exactly.  In addition, the first element is checked
    against the value recomputed from the pair counts in ``result[1:4]``.

    Parameters
    ----------
    event_indicator, event_time, estimate
        Passed through unchanged to ``concordance_index_censored``.
    expected : sequence
        Expected result, laid out like the return value of
        ``concordance_index_censored``.

    Raises
    ------
    AssertionError
        If any of the comparisons fails.
    """
    result = concordance_index_censored(event_indicator, event_time, estimate)
    assert_array_equal(result[1:], expected[1:])

    # The reported index must be consistent with the reported pair counts.
    concordant, discordant, tied_risk = result[1:4]
    cc = (concordant + 0.5 * tied_risk) / (concordant + discordant + tied_risk)
    assert_almost_equal(result[0], cc)
    assert_almost_equal(result[0], expected[0])


def assert_survival_function_properties(surv_fns):
    """Check that every row of ``surv_fns`` looks like a survival function.

    Parameters
    ----------
    surv_fns : numpy.ndarray
        Two-dimensional array with one survival function per row and one
        time point per column.

    Raises
    ------
    AssertionError
        If a value is not finite, is negative or larger than 1, if a row
        increases anywhere, if more than half of the rows start at a value
        that is not closer to 1 than to 0, or if more than half of the rows
        end at a value closer to 1 than to 0.
    """
    if not np.isfinite(surv_fns).all():
        raise AssertionError(
            "survival function contains values that are not finite")
    if np.any(surv_fns < 0.0):
        raise AssertionError("survival function contains negative values")
    if np.any(surv_fns > 1.0):
        raise AssertionError("survival function contains values larger 1")

    # Differences between consecutive time points, computed per row.
    d = np.apply_along_axis(np.diff, 1, surv_fns)
    if np.any(d > 0):
        raise AssertionError(
            "survival functions are not monotonically decreasing")

    # survival function at first time point
    num_closer_to_zero = np.sum(1.0 - surv_fns[:, 0] >= surv_fns[:, 0])
    if num_closer_to_zero / surv_fns.shape[0] > 0.5:
        raise AssertionError(
            f"most ({num_closer_to_zero}) probabilities at first time point are closer to 0 than 1")

    # survival function at last time point
    num_closer_to_one = np.sum(1.0 - surv_fns[:, -1] < surv_fns[:, -1])
    if num_closer_to_one / surv_fns.shape[0] > 0.5:
        raise AssertionError(
            f"most ({num_closer_to_one}) probabilities at last time point are closer to 1 than 0")


def assert_chf_properties(chf):
    """Check that every row of ``chf`` looks like a cumulative hazard function.

    Parameters
    ----------
    chf : numpy.ndarray
        Two-dimensional array with one function per row and one time point
        per column.

    Raises
    ------
    AssertionError
        If a value is not finite or is negative, if a row decreases
        anywhere, or if more than half of the rows start at a value
        closer to 1 than to 0.
    """
    if not np.isfinite(chf).all():
        raise AssertionError("chf contains values that are not finite")
    if np.any(chf < 0.0):
        raise AssertionError("chf contains negative values")

    # Differences between consecutive time points, computed per row.
    d = np.apply_along_axis(np.diff, 1, chf)
    if np.any(d < 0):
        raise AssertionError("chf are not monotonically increasing")

    # chf at first time point
    num_closer_to_one = np.sum(1.0 - chf[:, 0] < chf[:, 0])
    if num_closer_to_one / chf.shape[0] > 0.5:
        raise AssertionError(
            f"most ({num_closer_to_one}) hazard rates at first time point are closer to 1 than 0")


def _is_survival_estimator(x):
    """Return whether ``x`` is a public, non-transformer estimator class.

    ``x`` qualifies when it is a class derived from ``BaseEstimator`` but not
    from ``TransformerMixin``, its name has no leading underscore, and it is
    defined in a ``survivalist`` submodule other than ``metrics`` or
    ``nonparametric``.
    """
    return (
        inspect.isclass(x)
        and issubclass(x, BaseEstimator)
        and not issubclass(x, TransformerMixin)
        and x.__module__.startswith("survivalist.")
        and not x.__name__.startswith("_")
        # The prefix check above guarantees a second dotted component.
        and x.__module__.split(".", 2)[1] not in {"metrics", "nonparametric"}
    )


def all_survival_estimators():
    """Collect the concrete estimator classes found in ``survivalist``.

    Every module below the package directory is imported, except those whose
    dotted name starts with ``survivalist.meta``.  Classes accepted by
    ``_is_survival_estimator`` that are not abstract are returned.

    Returns
    -------
    set of type
        The discovered estimator classes.
    """
    root = str(Path(survivalist.__file__).parent)
    all_classes = set()
    for _importer, modname, _ispkg in pkgutil.walk_packages(path=[root], prefix="survivalist."):
        # meta-estimators require base estimators
        if modname.startswith("survivalist.meta"):
            continue
        module = import_module(modname)
        for _name, cls in inspect.getmembers(module, _is_survival_estimator):
            if inspect.isabstract(cls):
                continue
            all_classes.add(cls)
    return all_classes


class FixtureParameterFactory:
    """Base class that turns ``data_*`` members into pytest parameters."""

    def get_cases(self):
        """Build one ``pytest.param`` per member whose name starts with ``data_``.

        Each such member is called without arguments; the returned values are
        unpacked as the parameter's positional values and the member name is
        used as the parameter id.

        Returns
        -------
        list
            The parameters, in the order ``inspect.getmembers`` yields the
            members.
        """
        return [
            pytest.param(*func(), id=name)
            for name, func in inspect.getmembers(self)
            if name.startswith("data_")
        ]
def assert_cindex_almost_equal(event_indicator, event_time, estimate, expected):
    """Compare the output of ``concordance_index_censored`` with ``expected``.

    The first element of the result is compared approximately; all remaining
    elements are compared exactly.  In addition, the first element is checked
    against the value recomputed from the pair counts in ``result[1:4]``.

    Parameters
    ----------
    event_indicator, event_time, estimate
        Passed through unchanged to ``concordance_index_censored``.
    expected : sequence
        Expected result, laid out like the return value of
        ``concordance_index_censored``.

    Raises
    ------
    AssertionError
        If any of the comparisons fails.
    """
    result = concordance_index_censored(event_indicator, event_time, estimate)
    assert_array_equal(result[1:], expected[1:])

    # The reported index must be consistent with the reported pair counts.
    concordant, discordant, tied_risk = result[1:4]
    cc = (concordant + 0.5 * tied_risk) / (concordant + discordant + tied_risk)
    assert_almost_equal(result[0], cc)
    assert_almost_equal(result[0], expected[0])
def assert_survival_function_properties(surv_fns):
    """Check that every row of ``surv_fns`` looks like a survival function.

    Parameters
    ----------
    surv_fns : numpy.ndarray
        Two-dimensional array with one survival function per row and one
        time point per column.

    Raises
    ------
    AssertionError
        If a value is not finite, is negative or larger than 1, if a row
        increases anywhere, if more than half of the rows start at a value
        that is not closer to 1 than to 0, or if more than half of the rows
        end at a value closer to 1 than to 0.
    """
    if not np.isfinite(surv_fns).all():
        raise AssertionError(
            "survival function contains values that are not finite")
    if np.any(surv_fns < 0.0):
        raise AssertionError("survival function contains negative values")
    if np.any(surv_fns > 1.0):
        raise AssertionError("survival function contains values larger 1")

    # Differences between consecutive time points, computed per row.
    d = np.apply_along_axis(np.diff, 1, surv_fns)
    if np.any(d > 0):
        raise AssertionError(
            "survival functions are not monotonically decreasing")

    # survival function at first time point
    num_closer_to_zero = np.sum(1.0 - surv_fns[:, 0] >= surv_fns[:, 0])
    if num_closer_to_zero / surv_fns.shape[0] > 0.5:
        raise AssertionError(
            f"most ({num_closer_to_zero}) probabilities at first time point are closer to 0 than 1")

    # survival function at last time point
    num_closer_to_one = np.sum(1.0 - surv_fns[:, -1] < surv_fns[:, -1])
    if num_closer_to_one / surv_fns.shape[0] > 0.5:
        raise AssertionError(
            f"most ({num_closer_to_one}) probabilities at last time point are closer to 1 than 0")
def assert_chf_properties(chf):
    """Check that every row of ``chf`` looks like a cumulative hazard function.

    Parameters
    ----------
    chf : numpy.ndarray
        Two-dimensional array with one function per row and one time point
        per column.

    Raises
    ------
    AssertionError
        If a value is not finite or is negative, if a row decreases
        anywhere, or if more than half of the rows start at a value
        closer to 1 than to 0.
    """
    if not np.isfinite(chf).all():
        raise AssertionError("chf contains values that are not finite")
    if np.any(chf < 0.0):
        raise AssertionError("chf contains negative values")

    # Differences between consecutive time points, computed per row.
    d = np.apply_along_axis(np.diff, 1, chf)
    if np.any(d < 0):
        raise AssertionError("chf are not monotonically increasing")

    # chf at first time point
    num_closer_to_one = np.sum(1.0 - chf[:, 0] < chf[:, 0])
    if num_closer_to_one / chf.shape[0] > 0.5:
        raise AssertionError(
            f"most ({num_closer_to_one}) hazard rates at first time point are closer to 1 than 0")
def all_survival_estimators():
    """Collect the concrete estimator classes found in ``survivalist``.

    Every module below the package directory is imported, except those whose
    dotted name starts with ``survivalist.meta``.  Classes accepted by
    ``_is_survival_estimator`` that are not abstract are returned.

    Returns
    -------
    set of type
        The discovered estimator classes.
    """
    root = str(Path(survivalist.__file__).parent)
    all_classes = set()
    for _importer, modname, _ispkg in pkgutil.walk_packages(path=[root], prefix="survivalist."):
        # meta-estimators require base estimators
        if modname.startswith("survivalist.meta"):
            continue
        module = import_module(modname)
        for _name, cls in inspect.getmembers(module, _is_survival_estimator):
            if inspect.isabstract(cls):
                continue
            all_classes.add(cls)
    return all_classes
class
FixtureParameterFactory: