survivalist.testing

  1# This program is free software: you can redistribute it and/or modify
  2# it under the terms of the GNU General Public License as published by
  3# the Free Software Foundation, either version 3 of the License, or
  4# (at your option) any later version.
  5#
  6# This program is distributed in the hope that it will be useful,
  7# but WITHOUT ANY WARRANTY; without even the implied warranty of
  8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  9# GNU General Public License for more details.
 10#
 11# You should have received a copy of the GNU General Public License
 12# along with this program.  If not, see <http://www.gnu.org/licenses/>.
 13from importlib import import_module
 14import inspect
 15from pathlib import Path
 16import pkgutil
 17
 18import numpy as np
 19from numpy.testing import assert_almost_equal, assert_array_equal
 20import pytest
 21from sklearn.base import BaseEstimator, TransformerMixin
 22
 23import survivalist
 24from survivalist.metrics import concordance_index_censored
 25
 26
 27def assert_cindex_almost_equal(event_indicator, event_time, estimate, expected):
 28    result = concordance_index_censored(event_indicator, event_time, estimate)
 29    assert_array_equal(result[1:], expected[1:])
 30    concordant, discordant, tied_risk = result[1:4]
 31    cc = (concordant + 0.5 * tied_risk) / (concordant + discordant + tied_risk)
 32    assert_almost_equal(result[0], cc)
 33    assert_almost_equal(result[0], expected[0])
 34
 35
 36def assert_survival_function_properties(surv_fns):
 37    if not np.isfinite(surv_fns).all():
 38        raise AssertionError(
 39            "survival function contains values that are not finite")
 40    if np.any(surv_fns < 0.0):
 41        raise AssertionError("survival function contains negative values")
 42    if np.any(surv_fns > 1.0):
 43        raise AssertionError("survival function contains values larger 1")
 44
 45    d = np.apply_along_axis(np.diff, 1, surv_fns)
 46    if np.any(d > 0):
 47        raise AssertionError(
 48            "survival functions are not monotonically decreasing")
 49
 50    # survival function at first time point
 51    num_closer_to_zero = np.sum(1.0 - surv_fns[:, 0] >= surv_fns[:, 0])
 52    if num_closer_to_zero / surv_fns.shape[0] > 0.5:
 53        raise AssertionError(
 54            f"most ({num_closer_to_zero}) probabilities at first time point are closer to 0 than 1")
 55
 56    # survival function at last time point
 57    num_closer_to_one = np.sum(1.0 - surv_fns[:, -1] < surv_fns[:, -1])
 58    if num_closer_to_one / surv_fns.shape[0] > 0.5:
 59        raise AssertionError(
 60            f"most ({num_closer_to_one}) probabilities at last time point are closer to 1 than 0")
 61
 62
 63def assert_chf_properties(chf):
 64    if not np.isfinite(chf).all():
 65        raise AssertionError("chf contains values that are not finite")
 66    if np.any(chf < 0.0):
 67        raise AssertionError("chf contains negative values")
 68
 69    d = np.apply_along_axis(np.diff, 1, chf)
 70    if np.any(d < 0):
 71        raise AssertionError("chf are not monotonically increasing")
 72
 73    # chf at first time point
 74    num_closer_to_one = np.sum(1.0 - chf[:, 0] < chf[:, 0])
 75    if num_closer_to_one / chf.shape[0] > 0.5:
 76        raise AssertionError(
 77            f"most ({num_closer_to_one}) hazard rates at first time point are closer to 1 than 0")
 78
 79
 80def _is_survival_estimator(x):
 81    return (
 82        inspect.isclass(x)
 83        and issubclass(x, BaseEstimator)
 84        and not issubclass(x, TransformerMixin)
 85        and x.__module__.startswith("survivalist.")
 86        and not x.__name__.startswith("_")
 87        and x.__module__.split(".", 2)[1] not in {"metrics", "nonparametric"}
 88    )
 89
 90
 91def all_survival_estimators():
 92    root = str(Path(survivalist.__file__).parent)
 93    all_classes = []
 94    for _importer, modname, _ispkg in pkgutil.walk_packages(path=[root], prefix="survivalist."):
 95        # meta-estimators require base estimators
 96        if modname.startswith("survivalist.meta"):
 97            continue
 98        module = import_module(modname)
 99        for _name, cls in inspect.getmembers(module, _is_survival_estimator):
100            if inspect.isabstract(cls):
101                continue
102            all_classes.append(cls)
103    return set(all_classes)
104
105
class FixtureParameterFactory:
    """Base class that turns ``data_*`` members into pytest parameter sets.

    Subclasses define zero-argument ``data_*`` members that return an
    iterable of values. ``get_cases`` unpacks each result into one
    ``pytest.param``, using the member name as the test id.
    """

    def get_cases(self):
        """Return one ``pytest.param`` per ``data_*`` member, ordered by name."""
        return [
            pytest.param(*producer(), id=member_name)
            for member_name, producer in inspect.getmembers(self)
            if member_name.startswith("data_")
        ]
def assert_cindex_almost_equal(event_indicator, event_time, estimate, expected):
    """Check ``concordance_index_censored`` output against expected values.

    Every element after the first (the pair counts) must equal the
    corresponding element of ``expected`` exactly. The first element (the
    concordance index) must agree, to ``assert_almost_equal`` precision, both
    with the index recomputed from the returned counts and with
    ``expected[0]``.

    Parameters
    ----------
    event_indicator, event_time, estimate
        Forwarded unchanged to ``concordance_index_censored``.
    expected : sequence
        Expected return value of ``concordance_index_censored``.

    Raises
    ------
    AssertionError
        If any of the comparisons fails.
    """
    outcome = concordance_index_censored(event_indicator, event_time, estimate)
    assert_array_equal(outcome[1:], expected[1:])

    n_concordant, n_discordant, n_tied_risk = outcome[1:4]
    # Pairs tied on the risk estimate contribute half a concordant pair.
    n_total = n_concordant + n_discordant + n_tied_risk
    recomputed = (n_concordant + 0.5 * n_tied_risk) / n_total

    cindex = outcome[0]
    assert_almost_equal(cindex, recomputed)
    assert_almost_equal(cindex, expected[0])
def assert_survival_function_properties(surv_fns):
    """Assert that a set of survival functions is well-formed.

    Parameters
    ----------
    surv_fns : array-like, 2-D
        One survival function per row, evaluated at ordered time points
        along the columns. Converted with ``np.asarray``, so nested lists
        are accepted.

    Raises
    ------
    AssertionError
        Raised in any of these cases:

        * any value is non-finite or lies outside [0, 1];
        * any row increases between consecutive time points;
        * more than half of the rows start closer to 0 than to 1 (ties count
          as closer to 0);
        * more than half of the rows end closer to 1 than to 0.
    """
    surv_fns = np.asarray(surv_fns)
    if not np.isfinite(surv_fns).all():
        raise AssertionError(
            "survival function contains values that are not finite")
    if np.any(surv_fns < 0.0):
        raise AssertionError("survival function contains negative values")
    if np.any(surv_fns > 1.0):
        raise AssertionError("survival function contains values larger 1")

    # Row-wise differences between consecutive time points in one vectorized
    # call (np.apply_along_axis loops in Python and rejects zero-row input).
    if np.any(np.diff(surv_fns, axis=1) > 0):
        raise AssertionError(
            "survival functions are not monotonically decreasing")

    # The majority tests compare a count against half the number of rows
    # instead of dividing by it, so empty input cannot produce 0/0.
    half_of_rows = 0.5 * surv_fns.shape[0]

    # survival function at first time point
    num_closer_to_zero = np.count_nonzero(1.0 - surv_fns[:, 0] >= surv_fns[:, 0])
    if num_closer_to_zero > half_of_rows:
        raise AssertionError(
            f"most ({num_closer_to_zero}) probabilities at first time point are closer to 0 than 1")

    # survival function at last time point
    num_closer_to_one = np.count_nonzero(1.0 - surv_fns[:, -1] < surv_fns[:, -1])
    if num_closer_to_one > half_of_rows:
        raise AssertionError(
            f"most ({num_closer_to_one}) probabilities at last time point are closer to 1 than 0")
def assert_chf_properties(chf):
    """Assert that a set of cumulative hazard functions is well-formed.

    Parameters
    ----------
    chf : array-like, 2-D
        One cumulative hazard function per row, evaluated at ordered time
        points along the columns. Converted with ``np.asarray``, so nested
        lists are accepted.

    Raises
    ------
    AssertionError
        Raised in any of these cases:

        * any value is non-finite or negative;
        * any row decreases between consecutive time points;
        * more than half of the rows start closer to 1 than to 0.
    """
    chf = np.asarray(chf)
    if not np.isfinite(chf).all():
        raise AssertionError("chf contains values that are not finite")
    if np.any(chf < 0.0):
        raise AssertionError("chf contains negative values")

    # Row-wise differences between consecutive time points in one vectorized
    # call (np.apply_along_axis loops in Python and rejects zero-row input).
    if np.any(np.diff(chf, axis=1) < 0):
        raise AssertionError("chf are not monotonically increasing")

    # chf at first time point; compare the count against half the number of
    # rows instead of dividing by it, so empty input cannot produce 0/0.
    num_closer_to_one = np.count_nonzero(1.0 - chf[:, 0] < chf[:, 0])
    if num_closer_to_one > 0.5 * chf.shape[0]:
        raise AssertionError(
            f"most ({num_closer_to_one}) hazard rates at first time point are closer to 1 than 0")
def all_survival_estimators():
    """Collect the concrete survival estimator classes shipped in survivalist.

    Walks every module under the ``survivalist`` package directory, imports
    it, and keeps the non-abstract classes accepted by
    ``_is_survival_estimator``. Modules whose names start with
    ``survivalist.meta`` are skipped.

    Returns
    -------
    set
        The discovered estimator classes.
    """
    package_dir = str(Path(survivalist.__file__).parent)
    found = set()
    for _finder, module_name, _is_pkg in pkgutil.walk_packages(
            path=[package_dir], prefix="survivalist."):
        # meta-estimators require base estimators
        if module_name.startswith("survivalist.meta"):
            continue
        module = import_module(module_name)
        found.update(
            candidate
            for _attr, candidate in inspect.getmembers(module, _is_survival_estimator)
            if not inspect.isabstract(candidate)
        )
    return found
class FixtureParameterFactory:
    """Base class that turns ``data_*`` members into pytest parameter sets.

    Subclasses define zero-argument ``data_*`` members that return an
    iterable of values. ``get_cases`` unpacks each result into one
    ``pytest.param``, using the member name as the test id.
    """

    def get_cases(self):
        """Return one ``pytest.param`` per ``data_*`` member, ordered by name."""
        return [
            pytest.param(*producer(), id=member_name)
            for member_name, producer in inspect.getmembers(self)
            if member_name.startswith("data_")
        ]