synthe

Top-level package for synthe.

 1"""Top-level package for synthe."""
 2
 3__author__ = """T. Moudiki"""
 4__email__ = "thierry.moudiki@gmail.com"
 5
 6from .adaptivehistsampler import AdaptiveHistogramSampler  # noqa: F401
 7from .distro_simulator import DistroSimulator  # noqa: F401
 8from .empirical_copula import EmpiricalCopula  # noqa: F401
 9from .stratified_sampling import StratifiedClusteringSubsampling
10from .row_subsampling import SubSampler
11from .healthsims import SmartHealthSimulator  # noqa: F401
12from .metrics import DistanceMetrics  # noqa: F401
13from .meboot import MaximumEntropyBootstrap
14from .ts_distro_simulator import TsDistroSimulator  # noqa: F401
15from .diversity_generator import DiversityGenerator  # noqa: F401
16
17__all__ = [
18    "AdaptiveHistogramSampler",
19    "DistroSimulator",
20    "EmpiricalCopula",
21    "StratifiedClusteringSubsampling",
22    "SubSampler",
23    "SmartHealthSimulator",
24    "DistanceMetrics",
25    "MaximumEntropyBootstrap",
26    "TsDistroSimulator",
27    "DiversityGenerator",
28]
class AdaptiveHistogramSampler:
    """Sample synthetic rows from a multivariate histogram of the data.

    Each column is cut into ``n_bins`` bins (equal-frequency edges when
    ``method == "quantile"``, equal-width edges otherwise). Every row is
    assigned to a joint bin, and new samples are produced in two steps:

    1. Pick an occupied joint bin with probability proportional to its count.
    2. Produce a point for that bin. This is an existing row, a uniform draw
       inside the bin's box, or an existing row plus Gaussian jitter.

    Parameters
    ----------
    n_bins : int, default=10
        Number of bins per column.
    method : str, default="quantile"
        ``"quantile"`` for equal-frequency edges; any other value gives
        equal-width edges between each column's min and max.
    seed : int, default=123
        Seed for the internal ``numpy.random.Generator``.
    """

    _OVERSAMPLE_METHODS = ("uniform", "bootstrap", "jitter")

    def __init__(self, n_bins=10, method="quantile", seed=123):
        self.n_bins = n_bins
        self.method = method
        self.rng = np.random.default_rng(seed)
        self.bin_edges = None  # per-column edge arrays, each of length n_bins + 1
        self.bin_indices = None  # (n, d) per-column bin index of every row
        self.unique_bins = None  # flat ids of the occupied joint bins
        self.bin_probs = None  # empirical probability of each occupied bin
        self.X = None
        self.n = None
        self.d = None

    def fit(self, X):
        """Estimate per-column bin edges and joint-bin probabilities.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features) or (n_samples,)
            Data to model. A 1-D array is treated as a single column.

        Returns
        -------
        self : AdaptiveHistogramSampler
        """
        X = np.asarray(X)
        if X.ndim == 1:
            X = X.reshape(-1, 1)
        self.X = X
        self.n, self.d = X.shape

        self.bin_edges = []
        for j in range(self.d):
            xj = X[:, j]
            if self.method == "quantile":
                edges_j = np.quantile(xj, np.linspace(0, 1, self.n_bins + 1))
            else:
                edges_j = np.linspace(xj.min(), xj.max(), self.n_bins + 1)
            self.bin_edges.append(edges_j)

        # digitize is 1-based; shift to 0-based and clip so the column maximum
        # (which lands past the last edge) is folded into the last bin.
        bin_idx = np.zeros((self.n, self.d), dtype=int)
        for j in range(self.d):
            bin_idx[:, j] = np.clip(
                np.digitize(X[:, j], self.bin_edges[j]) - 1, 0, self.n_bins - 1
            )
        self.bin_indices = bin_idx

        unique_bins, counts = np.unique(self._joint_bin_ids(), return_counts=True)
        self.unique_bins = unique_bins
        self.bin_probs = counts / counts.sum()
        return self

    def sample(
        self,
        n_samples,
        oversample=False,
        oversample_method="bootstrap",
        jitter_scale=0.05,
    ):
        """Draw ``n_samples`` synthetic rows.

        Parameters
        ----------
        n_samples : int
            Number of rows to draw.
        oversample : bool, default=False
            If False, return existing rows. If True, use ``oversample_method``.
        oversample_method : {"bootstrap", "uniform", "jitter"}
            "bootstrap" returns existing rows. "uniform" draws uniformly
            inside the chosen bin's box. "jitter" adds N(0, jitter_scale)
            noise to existing rows.
        jitter_scale : float, default=0.05
            Standard deviation of the jitter noise (same for every column).

        Raises
        ------
        RuntimeError
            If called before ``fit``.
        ValueError
            If ``oversample`` is True and ``oversample_method`` is unknown.
        """
        if self.bin_probs is None:
            raise RuntimeError("You must call `fit` before `sample`.")
        # Validate before drawing so a bad argument leaves the RNG untouched.
        if oversample and oversample_method not in self._OVERSAMPLE_METHODS:
            raise ValueError(f"Unknown oversample_method: {oversample_method}")

        chosen_bins = self.rng.choice(
            self.unique_bins, size=n_samples, p=self.bin_probs
        )

        if not oversample:
            return self._subsample_existing(chosen_bins)
        if oversample_method == "uniform":
            return self._oversample_uniform(chosen_bins)
        if oversample_method == "jitter":
            return self._oversample_jitter(chosen_bins, jitter_scale)
        return self._oversample_bootstrap(chosen_bins)

    # --- Internal helpers ------------------------------------------------
    def _joint_bin_ids(self):
        """Flatten each row's per-column bin indices into one joint-bin id."""
        return np.ravel_multi_index(self.bin_indices.T, (self.n_bins,) * self.d)

    def _subsample_existing(self, chosen_bins):
        """Return one uniformly chosen original row from each requested bin."""
        bin_ids = self._joint_bin_ids()
        # Group row indices by joint bin once, instead of rescanning all n rows
        # for every requested sample. The stable sort keeps indices ascending
        # within each bin, as a per-bin np.where scan would.
        order = np.argsort(bin_ids, kind="stable")
        present, starts = np.unique(bin_ids[order], return_index=True)
        rows_by_bin = dict(zip(present.tolist(), np.split(order, starts[1:])))
        return np.array(
            [self.X[self.rng.choice(rows_by_bin[int(b)])] for b in chosen_bins]
        )

    def _oversample_uniform(self, chosen_bins):
        """Draw each coordinate uniformly between the chosen bin's edges."""
        shape = (self.n_bins,) * self.d
        X_sampled = []
        for b in chosen_bins:
            multi_idx = np.unravel_index(b, shape)
            X_sampled.append(
                [
                    self.rng.uniform(self.bin_edges[j][bi], self.bin_edges[j][bi + 1])
                    for j, bi in enumerate(multi_idx)
                ]
            )
        return np.array(X_sampled)

    def _oversample_bootstrap(self, chosen_bins):
        """Resample existing rows (same as the non-oversampling path)."""
        return self._subsample_existing(chosen_bins)

    def _oversample_jitter(self, chosen_bins, jitter_scale):
        """Resample existing rows and add isotropic Gaussian noise."""
        base_points = self._subsample_existing(chosen_bins)
        noise = self.rng.normal(scale=jitter_scale, size=base_points.shape)
        return base_points + noise

    # --- Visualization ----------------------------------------------------
    def plot_comparison(self, X_sampled, bins=30):
        """Plot joint 2D histogram and marginal distributions (for d=2).

        The original data is drawn in blue and ``X_sampled`` in red. Both are
        binned on the same edges, computed from the pooled data, so the
        overlaid histograms are directly comparable.

        Parameters
        ----------
        X_sampled : array-like, shape (m, 2)
        bins : int, sequence or str, default=30
            Anything accepted by ``numpy.histogram_bin_edges``.
        """
        if self.d != 2:
            raise ValueError("plot_comparison only supports 2D currently.")
        X_sampled = np.asarray(X_sampled)

        edges = [
            np.histogram_bin_edges(
                np.concatenate([self.X[:, j], X_sampled[:, j]]), bins=bins
            )
            for j in range(2)
        ]

        fig = plt.figure(figsize=(10, 10))
        grid = plt.GridSpec(4, 4, hspace=0.3, wspace=0.3)

        main_ax = fig.add_subplot(grid[1:, :-1])
        top_ax = fig.add_subplot(grid[0, :-1], sharex=main_ax)  # X1 marginal
        right_ax = fig.add_subplot(grid[1:, -1], sharey=main_ax)  # X2 marginal

        for data, cmap in ((self.X, "Blues"), (X_sampled, "Reds")):
            main_ax.hist2d(data[:, 0], data[:, 1], bins=edges, alpha=0.5, cmap=cmap)
        main_ax.set_xlabel("X1")
        main_ax.set_ylabel("X2")
        main_ax.set_title("Joint distribution")

        for data, color in ((self.X, "blue"), (X_sampled, "red")):
            top_ax.hist(data[:, 0], bins=edges[0], color=color, alpha=0.5, density=True)
            right_ax.hist(
                data[:, 1],
                bins=edges[1],
                orientation="horizontal",
                color=color,
                alpha=0.5,
                density=True,
            )

        top_ax.axis("off")
        right_ax.axis("off")

        plt.show()

    # --- Goodness of fit tests -------------------------------------------
    def goodness_of_fit(self, X_sampled):
        """Compare marginals with Kolmogorov-Smirnov and Anderson-Darling tests.

        Parameters
        ----------
        X_sampled : array-like, shape (m, d) or (m,) when d == 1

        Returns
        -------
        dict
            ``{"dim_j": {"ks_statistic", "ks_pvalue", "ad_statistic",
            "ad_significance_level"}}`` for each column j.

        Raises
        ------
        RuntimeError
            If called before ``fit``.
        """
        if self.X is None:
            raise RuntimeError("You must call `fit` before `goodness_of_fit`.")
        X_sampled = np.asarray(X_sampled)
        if X_sampled.ndim == 1:
            X_sampled = X_sampled.reshape(-1, 1)

        results = {}
        for j in range(self.d):
            x_orig = self.X[:, j]
            x_samp = X_sampled[:, j]

            ks_stat, ks_p = stats.ks_2samp(x_orig, x_samp)
            ad_result = stats.anderson_ksamp([x_orig, x_samp])

            results[f"dim_{j}"] = {
                "ks_statistic": ks_stat,
                "ks_pvalue": ks_p,
                "ad_statistic": ad_result.statistic,
                "ad_significance_level": ad_result.significance_level,
            }
        return results
def fit(self, X):
    """Estimate per-column bin edges and joint-bin probabilities.

    Reads ``self.n_bins`` and ``self.method``. Sets ``X``, ``n``, ``d``,
    ``bin_edges``, ``bin_indices``, ``unique_bins`` and ``bin_probs`` on
    ``self``.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features) or (n_samples,)
        Data to model. A 1-D array is treated as a single column.

    Returns
    -------
    self
    """
    X = np.asarray(X)
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    self.X = X
    self.n, self.d = X.shape

    self.bin_edges = []
    for j in range(self.d):
        xj = X[:, j]
        if self.method == "quantile":
            edges_j = np.quantile(xj, np.linspace(0, 1, self.n_bins + 1))
        else:
            edges_j = np.linspace(xj.min(), xj.max(), self.n_bins + 1)
        self.bin_edges.append(edges_j)

    # digitize is 1-based; shift to 0-based and clip so the column maximum
    # (which lands past the last edge) is folded into the last bin.
    bin_idx = np.zeros((self.n, self.d), dtype=int)
    for j in range(self.d):
        bin_idx[:, j] = np.clip(
            np.digitize(X[:, j], self.bin_edges[j]) - 1, 0, self.n_bins - 1
        )
    self.bin_indices = bin_idx

    bin_ids = np.ravel_multi_index(self.bin_indices.T, (self.n_bins,) * self.d)
    unique_bins, counts = np.unique(bin_ids, return_counts=True)
    self.unique_bins = unique_bins
    self.bin_probs = counts / counts.sum()
    return self
class DistroSimulator:
  36class DistroSimulator:
  37    def __init__(
  38        self,
  39        kernel="rbf",
  40        backend="numpy",
  41        n_clusters=5,
  42        clustering_method="kmeans",
  43        kde_kernel="gaussian",
  44        random_state=None,
  45        conformalize=False,
  46        residual_sampling="bootstrap",
  47        block_size=None,
  48        gmm_components=3,
  49        use_rff="auto",
  50        rff_components="auto",
  51        rff_gamma=None,
  52        kernel_approximation="rff",
  53        force_rff_threshold=1000,
  54    ):
  55        """
  56        Initialize the multivariate data generator.
  57
  58        Parameters:
  59        -----------
  60        kernel : str, default='rbf'
  61            Kernel type for KernelRidge regression
  62        backend : str, default='numpy'
  63            Backend for distance calculations ('numpy', 'gpu', 'tpu')
  64        n_clusters : int, default=5
  65            Number of clusters for stratified splitting
  66        clustering_method : str, default='kmeans'
  67            Clustering method for stratification ('kmeans' or 'gmm')
  68        random_state : int, default=None
  69            Random seed for reproducibility
  70        conformalize : bool
  71            Use split conformal prediction or not
  72        residual_sampling : str, default='bootstrap'
  73            Method for sampling residuals ('bootstrap', 'kde', 'gmm', 'block-bootstrap', 'me-bootstrap').
  74            Where 'me-bootstrap' refers to Maximum Entropy Bootstrap.
  75        block_size : int, default=None
  76            Block size for block bootstrap (if applicable)
  77        gmm_components : int, default=3
  78            Number of components for GMM sampling
  79        use_rff : bool or 'auto', default='auto'
  80            Whether to use kernel approximation. 'auto' enables for large datasets
  81        rff_components : int or 'auto', default='auto'
  82            Number of approximation components. 'auto' chooses based on data size
  83        rff_gamma : float, default=None
  84            Gamma parameter for approximation. If None, will be tuned.
  85        kernel_approximation : str, default='rff'
  86            Approximation method ('rff' or 'nystroem')
  87        force_rff_threshold : int, default=1000
  88            Auto-enable RFF when n_samples exceeds this threshold
  89        """
  90        self.kernel = kernel
  91        self.backend = backend
  92        self.n_clusters = n_clusters
  93        self.clustering_method = clustering_method
  94        self.random_state = random_state
  95        self.conformalize = conformalize
  96        self.residual_sampling = residual_sampling
  97        self.block_size = block_size
  98        self.gmm_components = gmm_components
  99        self.use_rff = use_rff
 100        self.rff_components = rff_components
 101        self.rff_gamma = rff_gamma
 102        self.kernel_approximation = kernel_approximation
 103        self.force_rff_threshold = force_rff_threshold
 104        self.kde_kernel = kde_kernel
 105
 106        if random_state is not None:
 107            np.random.seed(random_state)
 108            if JAX_AVAILABLE:
 109                key = jax.random.PRNGKey(random_state)
 110        # Validate sampling method
 111        valid_sampling_methods = [
 112            "bootstrap",
 113            "kde",
 114            "gmm",
 115            "block-bootstrap",
 116            "me-bootstrap",
 117        ]
 118        if residual_sampling not in valid_sampling_methods:
 119            raise ValueError(
 120                f"residual_sampling must be one of {valid_sampling_methods}"
 121            )
 122        # Validate approximation method
 123        valid_approximations = ["rff", "nystroem"]
 124        if kernel_approximation not in valid_approximations:
 125            raise ValueError(
 126                f"kernel_approximation must be one of {valid_approximations}"
 127            )
 128        # Initialize JAX if using GPU/TPU backend
 129        if backend in ["gpu", "tpu"] and JAX_AVAILABLE:
 130            self._setup_jax_backend()
 131        elif backend in ["gpu", "tpu"] and not JAX_AVAILABLE:
 132            print("JAX not available. Falling back to NumPy backend.")
 133            self.backend = "numpy"
 134        # Initialize attributes that will be set during fitting
 135        self.model = None
 136        self.residuals_ = None
 137        self.X_dist = None
 138        self.is_fitted = False
 139        self.best_params_ = None
 140        self.best_score_ = None
 141        self.cluster_labels_ = None
 142        self.cluster_model_ = None
 143        self.kde_model_ = None
 144        self.gmm_model_ = None
 145        self.scaler_ = None
 146        self.actual_rff_components_ = None
 147        self.actual_use_rff_ = None
 148
 149    def _setup_jax_backend(self):
 150        """Setup JAX backend for GPU/TPU acceleration."""
 151        if not JAX_AVAILABLE:
 152            raise ImportError("JAX is required for GPU/TPU backend")
 153
 154        # JIT compiled distance functions
 155        @jit
 156        def pairwise_sq_dists_jax(X1, X2):
 157            X1_sq = jnp.sum(X1**2, axis=1)[:, jnp.newaxis]
 158            X2_sq = jnp.sum(X2**2, axis=1)[jnp.newaxis, :]
 159            return X1_sq + X2_sq - 2 * X1 @ X2.T
 160
 161        @jit
 162        def cdist_jax(X1, X2):
 163            return vmap(
 164                lambda x: vmap(lambda y: jnp.sqrt(jnp.sum((x - y) ** 2)))(X2)
 165            )(X1)
 166
 167        self._pairwise_sq_dists_jax = pairwise_sq_dists_jax
 168        self._cdist_jax = cdist_jax
 169
 170    def _determine_components(self, n_samples):
 171        """Automatically determine optimal number of components."""
 172        if self.rff_components == "auto":
 173            # Optimized heuristic based on performance results
 174            if n_samples < 500:
 175                return min(50, n_samples)
 176            elif n_samples < 2000:
 177                return min(100, n_samples // 2)
 178            elif n_samples < 5000:
 179                return min(150, n_samples // 3)
 180            elif n_samples < 10000:
 181                return min(200, n_samples // 4)
 182            else:
 183                return min(300, n_samples // 5)
 184        else:
 185            return self.rff_components
 186
 187    def _create_model(self, gamma, alpha, use_rff=None):
 188        """Create the appropriate model based on RFF setting."""
 189        if use_rff is None:
 190            use_rff = self.actual_use_rff_
 191
 192        if use_rff:
 193            # Use kernel approximation with Ridge regression
 194            if self.rff_gamma is not None:
 195                effective_gamma = self.rff_gamma
 196            else:
 197                effective_gamma = gamma
 198            # Determine number of components
 199            n_components = self.actual_rff_components_
 200
 201            if self.kernel_approximation == "rff":
 202                approximator = RBFSampler(
 203                    gamma=effective_gamma,
 204                    n_components=n_components,
 205                    random_state=self.random_state,
 206                )
 207            else:  # nystroem
 208                approximator = Nystroem(
 209                    kernel="rbf",
 210                    gamma=effective_gamma,
 211                    n_components=n_components,
 212                    random_state=self.random_state,
 213                )
 214            # Create pipeline with scaling, approximation, and Ridge
 215            return Pipeline(
 216                [
 217                    ("scaler", StandardScaler()),
 218                    ("approx", approximator),
 219                    ("ridge", Ridge(alpha=alpha)),
 220                ]
 221            )
 222        # Standard KernelRidge
 223        return KernelRidge(kernel=self.kernel, gamma=gamma, alpha=alpha)
 224
 225    def _fit_residual_sampler(self, **kwargs):
 226        """Fit the chosen residual sampling model."""
 227        if self.residuals_ is None or len(self.residuals_) == 0:
 228            raise ValueError("No residuals available for fitting sampler")
 229
 230        if self.residual_sampling == "kde":
 231            kernel_bandwidths = {"bandwidth": np.logspace(-6, 6, 150)}
 232            grid = GridSearchCV(
 233                KernelDensity(kernel=self.kde_kernel, **kwargs),
 234                param_grid=kernel_bandwidths,
 235            )
 236            grid.fit(self.residuals_)
 237            self.kde_model_ = grid.best_estimator_
 238            self.kde_model_.fit(self.residuals_)
 239
 240        elif self.residual_sampling == "gmm":
 241            self.gmm_model_ = GaussianMixture(
 242                n_components=min(self.gmm_components, len(self.residuals_)),
 243                random_state=self.random_state,
 244                covariance_type="full",
 245            )
 246            self.gmm_model_.fit(self.residuals_)
 247
 248    def _sample_residuals(self, num_samples):
 249        """Sample residuals using the chosen method."""
 250        if self.residuals_ is None:
 251            raise ValueError("No residuals available for sampling")
 252
 253        if self.residual_sampling == "bootstrap":
 254            # Original bootstrap method
 255            n = self.residuals_.shape[0]
 256            idx = np.random.choice(n, num_samples, replace=True)
 257            return self.residuals_[idx]
 258
 259        elif self.residual_sampling == "kde":
 260            # Kernel Density Estimation sampling
 261            if self.kde_model_ is None:
 262                raise ValueError(
 263                    "KDE model not fitted. Call _fit_residual_sampler first."
 264                )
 265            # Sample from KDE
 266            return self.kde_model_.sample(num_samples)
 267
 268        elif self.residual_sampling == "gmm":
 269            # Gaussian Mixture Model sampling
 270            if self.gmm_model_ is None:
 271                raise ValueError(
 272                    "GMM model not fitted. Call _fit_residual_sampler first."
 273                )
 274            # Sample from GMM
 275            return self.gmm_model_.sample(num_samples)[0]
 276
 277        elif self.residual_sampling == "me-bootstrap":
 278            meb = MaximumEntropyBootstrap(random_state=self.random_state)
 279            # If residuals are shorter than num_samples, repeat or tile them
 280            residuals = self.residuals_.flatten()
 281            if residuals.shape[0] < num_samples:
 282                # Repeat residuals to reach num_samples
 283                repeats = int(np.ceil(num_samples / residuals.shape[0]))
 284                residuals = np.tile(residuals, repeats)[:num_samples]
 285            else:
 286                residuals = residuals[:num_samples]
 287            meb.fit(residuals)
 288            return meb.sample(1)[:, 0].reshape(-1, 1)
 289
 290        elif self.residual_sampling == "block-bootstrap":
 291            # Block Bootstrap sampling
 292            return bootstrap(
 293                self.residuals_, num_samples, block_size=self.block_size
 294            )
 295
 296        else:
 297            # Should not reach here due to validation in __init__
 298            raise ValueError(
 299                f"Unknown sampling method: {self.residual_sampling}"
 300            )
 301
 302    def _pairwise_sq_dists(self, X1, X2):
 303        """Compute pairwise squared Euclidean distances."""
 304        if self.backend in ["gpu", "tpu"] and JAX_AVAILABLE:
 305            X1_jax = jnp.array(X1)
 306            X2_jax = jnp.array(X2)
 307            result = self._pairwise_sq_dists_jax(X1_jax, X2_jax)
 308            return np.array(result)
 309        else:
 310            X1 = np.atleast_2d(X1)
 311            X2 = np.atleast_2d(X2)
 312            return (
 313                np.sum(X1**2, axis=1)[:, np.newaxis]
 314                + np.sum(X2**2, axis=1)[np.newaxis, :]
 315                - 2 * X1 @ X2.T
 316            )
 317
 318    def _compute_clusters(self, Y):
 319        """Compute cluster labels for stratified splitting."""
 320        if self.clustering_method == "kmeans":
 321            self.cluster_model_ = KMeans(
 322                n_clusters=self.n_clusters,
 323                random_state=self.random_state,
 324                n_init=10,
 325            )
 326        elif self.clustering_method == "gmm":
 327            self.cluster_model_ = GaussianMixture(
 328                n_components=self.n_clusters, random_state=self.random_state
 329            )
 330        else:
 331            raise ValueError("clustering_method must be 'kmeans' or 'gmm'")
 332        self.cluster_model_.fit(Y)
 333        return self.cluster_model_.predict(Y)
 334
 335    def _train_test_split(self, Y, n_train, sequential: bool = False):
 336        """Create train-test split. Stratified by clusters or sequential if specified."""
 337        try:
 338            n_samples = len(Y)
 339        except Exception:
 340            n_samples = Y.shape[0]
 341
 342        if sequential:
 343            # --- Sequential split (no shuffling, preserves temporal order)
 344            train_idx = np.arange(n_train)
 345            test_idx = np.arange(n_train, n_samples)
 346            return train_idx, test_idx
 347
 348        # --- Stratified split (default)
 349        self.cluster_labels_ = self._compute_clusters(Y)
 350        return train_test_split(
 351            np.arange(n_samples),
 352            train_size=n_train,
 353            stratify=self.cluster_labels_,
 354            random_state=self.random_state,
 355        )
 356
 357    def _mmd(self, u, v, kernel_sigma=1):
 358        """Maximum Mean Discrepancy between two distributions."""
 359        if u.ndim == 1:
 360            u = u.reshape(-1, 1)
 361        if v.ndim == 1:
 362            v = v.reshape(-1, 1)
 363
 364        def kmat(A, B):
 365            return np.exp(
 366                -self._pairwise_sq_dists(A, B) / (2 * kernel_sigma**2)
 367            )
 368
 369        return (
 370            np.mean(kmat(u, u)) + np.mean(kmat(v, v)) - 2 * np.mean(kmat(u, v))
 371        )
 372
 373    def _custom_energy_distance(self, u, v):
 374        """Energy distance between two distributions."""
 375        if u.ndim == 1:
 376            u = u.reshape(-1, 1)
 377        if v.ndim == 1:
 378            v = v.reshape(-1, 1)
 379
 380        n, d = u.shape
 381        m = v.shape[0]
 382
 383        if self.backend in ["gpu", "tpu"] and JAX_AVAILABLE:
 384            # JAX implementation
 385            u_jax = jnp.array(u)
 386            v_jax = jnp.array(v)
 387            dist_xx = self._cdist_jax(u_jax, u_jax)
 388            dist_yy = self._cdist_jax(v_jax, v_jax)
 389            dist_xy = self._cdist_jax(u_jax, v_jax)
 390            term1 = 2 * jnp.sum(dist_xy) / (n * m)
 391            term2 = jnp.sum(dist_xx) / (n * n)
 392            term3 = jnp.sum(dist_yy) / (m * m)
 393            return float(term1 - term2 - term3)
 394        else:
 395            # NumPy implementation
 396            dist_xx = cdist(u, u, metric="euclidean")
 397            dist_yy = cdist(v, v, metric="euclidean")
 398            dist_xy = cdist(u, v, metric="euclidean")
 399            term1 = 2 * np.sum(dist_xy) / (n * m)
 400            term2 = np.sum(dist_xx) / (n * n)
 401            term3 = np.sum(dist_yy) / (m * m)
 402            return term1 - term2 - term3
 403
 404    def _generate_pseudo(self, num_samples):
 405        """Generate synthetic data using the fitted model and residuals."""
 406        if not self.is_fitted:
 407            raise ValueError("Model not fitted. Call fit() first.")
 408        X_new = self.X_dist[:num_samples]
 409        # Handle prediction based on model type
 410        if self.actual_use_rff_:
 411            # For RFF pipeline
 412            preds = self.model.predict(X_new)
 413        else:
 414            # For standard KernelRidge
 415            preds = self.model.predict(X_new)
 416        if preds.ndim == 1:
 417            preds = preds.reshape(-1, 1)
 418        # Sample residuals using the chosen method
 419        return preds + self._sample_residuals(preds.shape[0])
 420
 421    def fit(self, Y, n_train=None, metric="energy", n_trials=50, **kwargs):
 422        """
 423        Fit the data generator to match the distribution of Y.
 424
 425        Parameters:
 426        -----------
 427        Y : array-like, shape (n_samples, n_features)
 428            Target multivariate data to emulate
 429        n_train : int, default=None
 430            Number of training samples (default: n_samples // 2)
 431        metric : str, default='energy'
 432            Distance metric for optimization ('energy', 'mmd', or 'wasserstein')
 433        n_trials : int, default=50
 434            Number of Optuna optimization trials
 435        **kwargs : dict
 436            Additional arguments for Optuna optimization
 437
 438        Returns:
 439        --------
 440        self : object
 441            Returns self
 442        """
 443        if Y.ndim == 1:
 444            Y = Y.reshape(-1, 1)
 445
 446        n, d = Y.shape
 447        self.n_features_ = d
 448        # Determine whether to use RFF
 449        if self.use_rff == "auto":
 450            self.actual_use_rff_ = n >= self.force_rff_threshold
 451        else:
 452            self.actual_use_rff_ = self.use_rff
 453        # Auto-enable RFF for large datasets with component determination
 454        if self.actual_use_rff_:
 455            self.actual_rff_components_ = self._determine_components(n)
 456            if self.use_rff == "auto":
 457                print(
 458                    f"Large dataset detected (n={n}). Auto-enabling {self.kernel_approximation.upper()} for scalability."
 459                )
 460
 461        if n_train is None:
 462            n_train = n // 2
 463        # Store the input distribution function
 464        self.X_dist = np.random.normal(0, 1, (n, d))
 465        # Create stratified train-test split
 466        if self.residual_sampling in ("block-bootstrap", "me-bootstrap"):
 467            train_idx, test_idx = self._train_test_split(
 468                Y, n_train, sequential=True
 469            )
 470        else:
 471            train_idx, test_idx = self._train_test_split(
 472                Y, n_train, sequential=False
 473            )
 474        Y_train = Y[train_idx]
 475        Y_test = Y[test_idx]
 476        X_train = self.X_dist[:n_train]
 477
 478        if self.conformalize:
 479
 480            def objective(trial):
 481                sigma = trial.suggest_float("sigma", 0.01, 10, log=True)
 482                lambd = trial.suggest_float("lambd", 1e-5, 1, log=True)
 483                gamma = 1 / (2 * sigma**2)
 484                # Determine proper training set size (50% of training data)
 485                n_proper_train = int(0.5 * len(Y_train))
 486                # Use stratified split for proper training and calibration sets
 487                proper_train_idx, calib_idx = self._train_test_split(
 488                    Y_train, n_proper_train
 489                )
 490                # Split the data
 491                X_proper_train = X_train[proper_train_idx]
 492                Y_proper_train = Y_train[proper_train_idx]
 493                X_calib = X_train[calib_idx]
 494                Y_calib = Y_train[calib_idx]
 495                # Standardize the response (Y) using proper training set statistics
 496                if not hasattr(self, "y_scaler_"):
 497                    self.y_scaler_ = StandardScaler()
 498                    Y_proper_train_scaled = self.y_scaler_.fit_transform(
 499                        Y_proper_train
 500                    )
 501                else:
 502                    Y_proper_train_scaled = self.y_scaler_.transform(
 503                        Y_proper_train
 504                    )
 505                # Create model with current parameters and fit on standardized proper training set
 506                model = self._create_model(gamma, lambd)
 507                model.fit(X_proper_train, Y_proper_train_scaled)
 508                # Get predictions on calibration set and transform back to original scale
 509                preds_calib_scaled = model.predict(X_calib)
 510                if preds_calib_scaled.ndim == 1:
 511                    preds_calib_scaled = preds_calib_scaled.reshape(-1, 1)
 512                # Transform predictions back to original scale
 513                preds_calib = self.y_scaler_.inverse_transform(
 514                    preds_calib_scaled
 515                )
 516                # Calculate residuals on calibration set in original scale
 517                res_calib = Y_calib - preds_calib
 518                # Standardize residuals using calibration set statistics
 519                if res_calib.ndim == 1:
 520                    res_mean = np.mean(res_calib)
 521                    res_std = np.std(res_calib, ddof=1)
 522                    # Avoid division by zero
 523                    res_std = res_std if res_std > 1e-10 else 1.0
 524                    res_calib_standardized = (res_calib - res_mean) / res_std
 525                else:
 526                    res_mean = np.mean(res_calib, axis=0)
 527                    res_std = np.std(res_calib, axis=0, ddof=1)
 528                    # Avoid division by zero
 529                    res_std = np.where(res_std > 1e-10, res_std, 1.0)
 530                    res_calib_standardized = (res_calib - res_mean) / res_std
 531
 532                # Store the calibrated standardized residuals for use in the generation method
 533                calibrated_residuals = (
 534                    res_calib_standardized * res_std + res_mean
 535                )
 536
 537                # Use the existing method to generate pseudo samples with conformal prediction
 538                Y_sim = self._generate_pseudo_with_model(
 539                    model, calibrated_residuals, len(Y_test)
 540                )
 541
 542                # Calculate distance metric
 543                if metric == "energy":
 544                    dist_val = self._custom_energy_distance(Y_test, Y_sim)
 545                elif metric == "mmd":
 546                    dist_val = self._mmd(Y_test, Y_sim)
 547                elif metric == "wasserstein" and d == 1:
 548                    dist_val = stats.wasserstein_distance(
 549                        Y_test.flatten(), Y_sim.flatten()
 550                    )
 551                else:
 552                    raise ValueError("Invalid metric for dimension")
 553
 554                return dist_val
 555
 556        else:
 557
 558            def objective(trial):
 559                sigma = trial.suggest_float("sigma", 0.01, 10, log=True)
 560                lambd = trial.suggest_float("lambd", 1e-5, 1, log=True)
 561                gamma = 1 / (2 * sigma**2)
 562                # Create model with current parameters
 563                model = self._create_model(gamma, lambd)
 564                model.fit(X_train, Y_train)
 565                preds_train = model.predict(X_train)
 566                if preds_train.ndim == 1:
 567                    preds_train = preds_train.reshape(-1, 1)
 568                res = Y_train - preds_train
 569                Y_sim = self._generate_pseudo_with_model(
 570                    model, res, len(Y_test)
 571                )
 572                if metric == "energy":
 573                    dist_val = self._custom_energy_distance(Y_test, Y_sim)
 574                elif metric == "mmd":
 575                    dist_val = self._mmd(Y_test, Y_sim)
 576                elif metric == "wasserstein" and d == 1:
 577                    dist_val = stats.wasserstein_distance(
 578                        Y_test.flatten(), Y_sim.flatten()
 579                    )
 580                else:
 581                    raise ValueError("Invalid metric for dimension")
 582                return dist_val
 583
 584        # Optimize hyperparameters
 585        study = optuna.create_study(direction="minimize")
 586        study.optimize(objective, n_trials=n_trials, **kwargs)
 587        # Store best parameters and fit final model
 588        self.best_params_ = study.best_params
 589        self.best_score_ = study.best_value
 590        sigma = self.best_params_["sigma"]
 591        lambd = self.best_params_["lambd"]
 592        gamma = 1 / (2 * sigma**2)
 593        # Fit final model with best parameters
 594        self.model = self._create_model(gamma, lambd)
 595        self.model.fit(X_train, Y_train)
 596        # Compute residuals
 597        preds_train = self.model.predict(X_train)
 598        if preds_train.ndim == 1:
 599            preds_train = preds_train.reshape(-1, 1)
 600        self.residuals_ = Y_train - preds_train
 601        # Fit the residual sampler
 602        self._fit_residual_sampler()
 603        self.is_fitted = True
 604        # Print final configuration
 605        if self.actual_use_rff_:
 606            print(
 607                f"  Using {self.kernel_approximation.upper()} with {self.actual_rff_components_} components"
 608            )
 609        else:
 610            print(f"  Using standard kernel method")
 611        return self
 612
 613    def _generate_pseudo_with_model(self, model, residuals, num_samples):
 614        """Helper method to generate data with a specific model."""
 615        X_new = self.X_dist[:num_samples]
 616
 617        # Handle prediction based on model type
 618        if hasattr(model, "named_steps"):
 619            # Pipeline (RFF or Nystroem)
 620            preds = model.predict(X_new)
 621        else:
 622            # Standard model
 623            preds = model.predict(X_new)
 624
 625        if preds.ndim == 1:
 626            preds = preds.reshape(-1, 1)
 627
 628        # Temporarily store original state
 629        original_residuals = self.residuals_
 630        original_kde = self.kde_model_
 631        original_gmm = self.gmm_model_
 632
 633        # Set residuals for this model
 634        self.residuals_ = residuals
 635
 636        # Fit sampler with the new residuals
 637        self._fit_residual_sampler()
 638
 639        # Sample residuals
 640        sampled_residuals = self._sample_residuals(num_samples)
 641
 642        # Restore original state
 643        self.residuals_ = original_residuals
 644        self.kde_model_ = original_kde
 645        self.gmm_model_ = original_gmm
 646
 647        return preds + sampled_residuals
 648
 649    def sample(self, n_samples=1):
 650        """
 651        Generate synthetic samples.
 652
 653        Parameters:
 654        -----------
 655        n_samples : int, default=1
 656            Number of samples to generate
 657
 658        Returns:
 659        --------
 660        Y_sim : array, shape (n_samples, n_features)
 661            Generated synthetic data
 662        """
 663        if not self.is_fitted:
 664            raise ValueError("Model not fitted. Call fit() first.")
 665        return self._generate_pseudo(n_samples)
 666
 667    def compare_approximation_methods(self, Y, n_train=None, n_trials=20):
 668        """
 669        Compare different kernel approximation methods.
 670
 671        Parameters:
 672        -----------
 673        Y : array-like
 674            Target data
 675        n_train : int, default=None
 676            Number of training samples
 677        n_trials : int, default=20
 678            Number of optimization trials
 679
 680        Returns:
 681        --------
 682        comparison_results : dict
 683            Comparison results
 684        """
 685        if Y.ndim == 1:
 686            Y = Y.reshape(-1, 1)
 687
 688        print("Comparing Kernel Approximation Methods...")
 689
 690        # Store original settings
 691        original_use_rff = self.use_rff
 692        original_approximation = self.kernel_approximation
 693        original_is_fitted = self.is_fitted
 694
 695        methods = ["rff", "nystroem"]
 696        results = {}
 697
 698        for method in methods:
 699            print(f"\nTesting {method.upper()}...")
 700            self.use_rff = True
 701            self.kernel_approximation = method
 702
 703            start_time = time()
 704            self.fit(Y, n_train=n_train, n_trials=n_trials)
 705            method_time = time() - start_time
 706            method_score = self.best_score_
 707            method_params = self.best_params_
 708
 709            results[method] = {
 710                "time": method_time,
 711                "score": method_score,
 712                "params": method_params,
 713                "components": self.actual_rff_components_,
 714            }
 715
 716        # Test standard method for comparison
 717        print(f"\nTesting Standard Kernel...")
 718        self.use_rff = False
 719        start_time = time()
 720        self.fit(Y, n_train=n_train, n_trials=n_trials)
 721        standard_time = time() - start_time
 722        standard_score = self.best_score_
 723        standard_params = self.best_params_
 724
 725        results["standard"] = {
 726            "time": standard_time,
 727            "score": standard_score,
 728            "params": standard_params,
 729            "components": "N/A",
 730        }
 731
 732        # Restore original settings
 733        self.use_rff = original_use_rff
 734        self.kernel_approximation = original_approximation
 735        self.is_fitted = original_is_fitted
 736
 737        # Print comparison
 738        print("\n" + "=" * 60)
 739        print("KERNEL APPROXIMATION COMPARISON RESULTS")
 740        print("=" * 60)
 741
 742        for method in ["standard"] + methods:
 743            data = results[method]
 744            print(f"\n{method.upper()}:")
 745            print(f"  Time: {data['time']:.2f}s")
 746            print(f"  Score: {data['score']:.6f}")
 747            print(f"  Components: {data['components']}")
 748            if method != "standard":
 749                speedup = standard_time / data["time"]
 750                score_ratio = data["score"] / standard_score
 751                print(f"  Speedup: {speedup:.2f}x")
 752                print(f"  Score Ratio: {score_ratio:.4f}")
 753
 754        return results
 755
 756    def compare_residual_sampling(self, n_samples=1000):
 757        """
 758        Compare different residual sampling methods visually.
 759
 760        Parameters:
 761        -----------
 762        n_samples : int, default=1000
 763            Number of samples to generate for comparison
 764        """
 765        if not self.is_fitted:
 766            raise ValueError("Model not fitted. Call fit() first.")
 767        # Store original sampling method
 768        original_sampling = self.residual_sampling
 769        # Generate samples with different methods
 770        sampling_methods = ["bootstrap", "kde", "gmm"]
 771        samples = {}
 772
 773        for method in sampling_methods:
 774            self.residual_sampling = method
 775            if method == "kde":
 776                self._fit_residual_sampler()
 777            elif method == "gmm":
 778                self._fit_residual_sampler()
 779            samples[method] = self._sample_residuals(n_samples)
 780        # Restore original method
 781        self.residual_sampling = original_sampling
 782        self._fit_residual_sampler()
 783        # Plot comparison
 784        n_dims = self.residuals_.shape[1]
 785        fig, axes = plt.subplots(
 786            n_dims,
 787            len(sampling_methods) + 1,
 788            figsize=(5 * (len(sampling_methods) + 1), 4 * n_dims),
 789        )
 790
 791        if n_dims == 1:
 792            axes = axes.reshape(1, -1)
 793
 794        for dim in range(n_dims):
 795            # Original residuals
 796            axes[dim, 0].hist(
 797                self.residuals_[:, dim], bins=30, alpha=0.7, density=True
 798            )
 799            axes[dim, 0].set_title(f"Original Residuals\nDim {dim+1}")
 800            axes[dim, 0].set_xlabel("Residual Value")
 801            axes[dim, 0].set_ylabel("Density")
 802
 803            # Sampled residuals
 804            for j, method in enumerate(sampling_methods):
 805                col = j + 1
 806                axes[dim, col].hist(
 807                    samples[method][:, dim], bins=30, alpha=0.7, density=True
 808                )
 809                axes[dim, col].set_title(
 810                    f"{method.upper()} Sampling\nDim {dim+1}"
 811                )
 812                axes[dim, col].set_xlabel("Residual Value")
 813                axes[dim, col].set_ylabel("Density")
 814
 815        plt.tight_layout()
 816        plt.show()
 817
 818        return samples
 819
 820    def _perm_test(self, Y_orig, Y_sim, stat_func, n_perm=1000):
 821        """Permutation test for distribution comparison."""
 822        if Y_orig.ndim == 1:
 823            Y_orig = Y_orig.reshape(-1, 1)
 824        if Y_sim.ndim == 1:
 825            Y_sim = Y_sim.reshape(-1, 1)
 826
 827        obs = stat_func(Y_orig, Y_sim)
 828        combined = np.vstack((Y_orig, Y_sim))
 829        n1 = Y_orig.shape[0]
 830        perms = np.zeros(n_perm)
 831
 832        for i in range(n_perm):
 833            idx = np.random.permutation(combined.shape[0])
 834            p1 = combined[idx[:n1]]
 835            p2 = combined[idx[n1:]]
 836            perms[i] = stat_func(p1, p2)
 837
 838        pval = (np.sum(perms >= obs) + 1) / (n_perm + 1)
 839        return obs, pval
 840
 841    def _fisher_z_test(self, r1, r2, n1, n2):
 842        """Fisher z-test for comparing correlation coefficients."""
 843        z1 = np.arctanh(r1)
 844        z2 = np.arctanh(r2)
 845        z = (z1 - z2) / np.sqrt(1 / (n1 - 3) + 1 / (n2 - 3))
 846        p = 2 * (1 - stats.norm.cdf(np.abs(z)))
 847        return z, p
 848
 849    def test_similarity(self, Y_orig, Y_sim, n_perm=1000):
 850        """
 851        Test statistical similarity between original and synthetic data.
 852
 853        Parameters:
 854        -----------
 855        Y_orig : array-like
 856            Original data
 857        Y_sim : array-like
 858            Synthetic data
 859        n_perm : int, default=1000
 860            Number of permutations for permutation tests
 861
 862        Returns:
 863        --------
 864        results : dict
 865            Dictionary containing test results
 866        """
 867        if Y_orig.ndim == 1:
 868            Y_orig = Y_orig.reshape(-1, 1)
 869        if Y_sim.ndim == 1:
 870            Y_sim = Y_sim.reshape(-1, 1)
 871
 872        d = Y_orig.shape[1]
 873        results = {}
 874        # Test 1: Perm with energy
 875        results["energy_perm"] = self._perm_test(
 876            Y_orig, Y_sim, self._custom_energy_distance, n_perm
 877        )
 878        # Test 2: Perm with MMD
 879        results["mmd_perm"] = self._perm_test(
 880            Y_orig, Y_sim, lambda u, v: self._mmd(u, v), n_perm
 881        )
 882
 883        # Test 3: Perm with avg Wasserstein on margins
 884        def avg_wass(u, v):
 885            return np.mean(
 886                [stats.wasserstein_distance(u[:, i], v[:, i]) for i in range(d)]
 887            )
 888
 889        results["avg_wass_perm"] = self._perm_test(
 890            Y_orig, Y_sim, avg_wass, n_perm
 891        )
 892        # Test 4: Min p-value from marginal KS tests
 893        ps_ks = [
 894            stats.ks_2samp(Y_orig[:, i], Y_sim[:, i]).pvalue for i in range(d)
 895        ]
 896        results["min_marginal_ks_p"] = min(ps_ks)
 897        # Test 5: Min p-value from marginal Anderson-Darling tests
 898        ps_ad = [
 899            stats.anderson_ksamp([Y_orig[:, i], Y_sim[:, i]]).significance_level
 900            for i in range(d)
 901        ]
 902        results["min_marginal_ad_p"] = min(ps_ad)
 903        # Test 6: Min p-value from marginal Cramer-von Mises tests
 904        ps_cvm = [
 905            stats.cramervonmises_2samp(Y_orig[:, i], Y_sim[:, i]).pvalue
 906            for i in range(d)
 907        ]
 908        results["min_marginal_cvm_p"] = min(ps_cvm)
 909        # Correlation test: Compare all pairwise correlations
 910        corr_results = {}
 911        pairs = [(i, j) for i in range(d) for j in range(i + 1, d)]
 912        for i, j in pairs:
 913            r_orig = stats.pearsonr(Y_orig[:, i], Y_orig[:, j])[0]
 914            r_sim = stats.pearsonr(Y_sim[:, i], Y_sim[:, j])[0]
 915            z, p = self._fisher_z_test(r_orig, r_sim, len(Y_orig), len(Y_sim))
 916            corr_results[f"corr_dim{i+1}_dim{j+1}"] = (r_orig, r_sim, z, p)
 917        results["corr_tests"] = corr_results
 918        return results
 919
 920    def compare_distributions(self, Y_orig, Y_sim, save_prefix=""):
 921        """
 922        Visual comparison of original and synthetic distributions.
 923
 924        Parameters:
 925        -----------
 926        Y_orig : array-like
 927            Original data
 928        Y_sim : array-like
 929            Synthetic data
 930        save_prefix : str, default=''
 931            Prefix for saving plots
 932        """
 933        if Y_orig.ndim == 1:
 934            Y_orig = Y_orig.reshape(-1, 1)
 935        if Y_sim.ndim == 1:
 936            Y_sim = Y_sim.reshape(-1, 1)
 937
 938        n, d = Y_orig.shape
 939
 940        # Create a figure with subplots for statistical tests
 941        fig, axes = plt.subplots(2, d, figsize=(6 * d, 10))
 942        if d == 1:
 943            axes = axes.reshape(2, 1)
 944
 945        # Statistical test results storage
 946        ks_results = []
 947        ad_results = []
 948
 949        for i in range(d):
 950            # Top row: Histograms with statistical test annotations
 951            ax_hist = axes[0, i]
 952
 953            # Plot histograms
 954            ax_hist.hist(
 955                Y_orig[:, i],
 956                alpha=0.5,
 957                label="Original",
 958                density=True,
 959                bins=20,
 960                color="blue",
 961            )
 962            ax_hist.hist(
 963                Y_sim[:, i],
 964                alpha=0.5,
 965                label="Simulated",
 966                density=True,
 967                bins=20,
 968                color="red",
 969            )
 970
 971            # Perform statistical tests
 972            # Kolmogorov-Smirnov test
 973            ks_stat, ks_pvalue = stats.ks_2samp(Y_orig[:, i], Y_sim[:, i])
 974            ks_results.append((ks_stat, ks_pvalue))
 975
 976            # Anderson-Darling test
 977            ad_result = stats.anderson_ksamp([Y_orig[:, i], Y_sim[:, i]])
 978            ad_stat = ad_result.statistic
 979            ad_critical = ad_result.critical_values
 980            ad_significance = ad_result.significance_level
 981            ad_results.append((ad_stat, ad_significance))
 982
 983            # Add test results to histogram plot
 984            textstr = "\n".join(
 985                (
 986                    f"KS test: p = {ks_pvalue:.4f}",
 987                    f"AD test: p < {ad_significance:.3f}",
 988                    f"AD stat: {ad_stat:.4f}",
 989                )
 990            )
 991            props = dict(boxstyle="round", facecolor="wheat", alpha=0.8)
 992            ax_hist.text(
 993                0.05,
 994                0.95,
 995                textstr,
 996                transform=ax_hist.transAxes,
 997                fontsize=10,
 998                verticalalignment="top",
 999                bbox=props,
1000            )
1001
1002            ax_hist.legend()
1003            ax_hist.set_title(
1004                f"Dimension {i+1} - Histograms with Statistical Tests"
1005            )
1006            ax_hist.set_xlabel("Value")
1007            ax_hist.set_ylabel("Density")
1008
1009            # Bottom row: ECDFs with KS test visualization
1010            ax_ecdf = axes[1, i]
1011
1012            # Compute ECDFs
1013            sorted_orig = np.sort(Y_orig[:, i])
1014            ecdf_orig = np.arange(1, len(sorted_orig) + 1) / len(sorted_orig)
1015            sorted_sim = np.sort(Y_sim[:, i])
1016            ecdf_sim = np.arange(1, len(sorted_sim) + 1) / len(sorted_sim)
1017
1018            # Plot ECDFs
1019            ax_ecdf.step(
1020                sorted_orig,
1021                ecdf_orig,
1022                label="Original",
1023                color="blue",
1024                linewidth=2,
1025            )
1026            ax_ecdf.step(
1027                sorted_sim,
1028                ecdf_sim,
1029                label="Simulated",
1030                color="red",
1031                linewidth=2,
1032            )
1033
1034            # Find the point of maximum difference for KS test
1035            # Combine and sort all values
1036            all_values = np.sort(np.concatenate([sorted_orig, sorted_sim]))
1037            # Compute ECDFs at all points
1038            ecdf_orig_all = np.searchsorted(
1039                sorted_orig, all_values, side="right"
1040            ) / len(sorted_orig)
1041            ecdf_sim_all = np.searchsorted(
1042                sorted_sim, all_values, side="right"
1043            ) / len(sorted_sim)
1044            # Find maximum difference
1045            diff = np.abs(ecdf_orig_all - ecdf_sim_all)
1046            max_idx = np.argmax(diff)
1047            max_x = all_values[max_idx]
1048            max_y1 = ecdf_orig_all[max_idx]
1049            max_y2 = ecdf_sim_all[max_idx]
1050
1051            # Mark the maximum difference point
1052            ax_ecdf.plot(
1053                [max_x, max_x],
1054                [max_y1, max_y2],
1055                "k-",
1056                linewidth=3,
1057                label=f"KS stat: {ks_stat:.4f}",
1058            )
1059            ax_ecdf.plot(max_x, max_y1, "ko", markersize=8)
1060            ax_ecdf.plot(max_x, max_y2, "ko", markersize=8)
1061
1062            ax_ecdf.legend()
1063            ax_ecdf.set_title(f"Dimension {i+1} - ECDFs with KS Statistic")
1064            ax_ecdf.set_xlabel("Value")
1065            ax_ecdf.set_ylabel("ECDF")
1066
1067        plt.tight_layout()
1068        if save_prefix:
1069            plt.savefig(
1070                f"{save_prefix}_statistical_comparison.png",
1071                dpi=300,
1072                bbox_inches="tight",
1073            )
1074        plt.show()
1075
1076        # Print comprehensive test results
1077        print("\n" + "=" * 60)
1078        print("COMPREHENSIVE STATISTICAL TEST RESULTS")
1079        print("=" * 60)
1080
1081        for i in range(d):
1082            ks_stat, ks_pvalue = ks_results[i]
1083            ad_stat, ad_significance = ad_results[i]
1084
1085            print(f"\nDimension {i+1}:")
1086            print(f"  Kolmogorov-Smirnov Test:")
1087            print(f"    Statistic: {ks_stat:.6f}")
1088            print(f"    p-value: {ks_pvalue:.6f}")
1089            print(
1090                f"    Significance: {'Not Significant' if ks_pvalue > 0.05 else 'SIGNIFICANT'}"
1091            )
1092
1093            print(f"  Anderson-Darling Test:")
1094            print(f"    Statistic: {ad_stat:.6f}")
1095            print(f"    Significance level: {ad_significance:.3f}")
1096            print(
1097                f"    Interpretation: {'Distributions differ' if ad_stat > ad_result.critical_values[2] else 'Distributions similar'}"
1098            )
1099
1100        # Create summary plot for all dimensions
1101        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
1102
1103        # KS test p-values across dimensions
1104        ks_pvalues = [result[1] for result in ks_results]
1105        dimensions = list(range(1, d + 1))
1106
1107        bars = ax1.bar(
1108            dimensions,
1109            ks_pvalues,
1110            color=["red" if p < 0.05 else "green" for p in ks_pvalues],
1111        )
1112        ax1.axhline(
1113            y=0.05, color="black", linestyle="--", alpha=0.7, label="α = 0.05"
1114        )
1115        ax1.set_xlabel("Dimension")
1116        ax1.set_ylabel("KS Test p-value")
1117        ax1.set_title("Kolmogorov-Smirnov Test Results\nby Dimension")
1118        ax1.set_xticks(dimensions)
1119        ax1.legend()
1120
1121        # Add value labels on bars
1122        for bar, pvalue in zip(bars, ks_pvalues):
1123            height = bar.get_height()
1124            ax1.text(
1125                bar.get_x() + bar.get_width() / 2.0,
1126                height,
1127                f"{pvalue:.3f}",
1128                ha="center",
1129                va="bottom",
1130            )
1131
1132        # AD test statistics across dimensions
1133        ad_stats = [result[0] for result in ad_results]
1134
1135        bars = ax2.bar(dimensions, ad_stats, color="skyblue")
1136        ax2.set_xlabel("Dimension")
1137        ax2.set_ylabel("AD Test Statistic")
1138        ax2.set_title("Anderson-Darling Test Statistics\nby Dimension")
1139        ax2.set_xticks(dimensions)
1140
1141        # Add value labels on bars
1142        for bar, stat in zip(bars, ad_stats):
1143            height = bar.get_height()
1144            ax2.text(
1145                bar.get_x() + bar.get_width() / 2.0,
1146                height,
1147                f"{stat:.3f}",
1148                ha="center",
1149                va="bottom",
1150            )
1151
1152        plt.tight_layout()
1153        if save_prefix:
1154            plt.savefig(
1155                f"{save_prefix}_test_summary.png", dpi=300, bbox_inches="tight"
1156            )
1157        plt.show()
1158
1159        # Additional: Q-Q plots for each dimension
1160        fig, axes = plt.subplots(1, d, figsize=(5 * d, 5))
1161        if d == 1:
1162            axes = [axes]
1163
1164        for i in range(d):
1165            # Sort data for Q-Q plot
1166            orig_sorted = np.sort(Y_orig[:, i])
1167            sim_sorted = np.sort(Y_sim[:, i])
1168
1169            # Generate theoretical quantiles
1170            n_orig = len(orig_sorted)
1171            n_sim = len(sim_sorted)
1172
1173            # Use smaller set for quantiles to avoid interpolation issues
1174            n_points = min(n_orig, n_sim, 1000)
1175            quantiles = np.linspace(0, 1, n_points)
1176
1177            orig_quantiles = np.quantile(orig_sorted, quantiles)
1178            sim_quantiles = np.quantile(sim_sorted, quantiles)
1179
1180            axes[i].plot(
1181                orig_quantiles, sim_quantiles, "o", alpha=0.6, markersize=3
1182            )
1183            min_val = min(orig_quantiles.min(), sim_quantiles.min())
1184            max_val = max(orig_quantiles.max(), sim_quantiles.max())
1185            axes[i].plot(
1186                [min_val, max_val],
1187                [min_val, max_val],
1188                "r--",
1189                alpha=0.8,
1190                linewidth=2,
1191            )
1192            axes[i].set_xlabel("Original Data Quantiles")
1193            axes[i].set_ylabel("Simulated Data Quantiles")
1194            axes[i].set_title(f"Dimension {i+1} - Q-Q Plot")
1195
1196            # Add correlation coefficient
1197            corr = np.corrcoef(orig_quantiles, sim_quantiles)[0, 1]
1198            axes[i].text(
1199                0.05,
1200                0.95,
1201                f"Corr: {corr:.4f}",
1202                transform=axes[i].transAxes,
1203                bbox=dict(
1204                    boxstyle="round,pad=0.3", facecolor="white", alpha=0.8
1205                ),
1206                verticalalignment="top",
1207            )
1208
1209        plt.tight_layout()
1210        if save_prefix:
1211            plt.savefig(
1212                f"{save_prefix}_qq_plots.png", dpi=300, bbox_inches="tight"
1213            )
1214        plt.show()
1215
1216        return {
1217            "ks_results": ks_results,
1218            "ad_results": ad_results,
1219            "dimensions": d,
1220        }
def fit(self, Y, n_train=None, metric='energy', n_trials=50, **kwargs):
421    def fit(self, Y, n_train=None, metric="energy", n_trials=50, **kwargs):
422        """
423        Fit the data generator to match the distribution of Y.
424
425        Parameters:
426        -----------
427        Y : array-like, shape (n_samples, n_features)
428            Target multivariate data to emulate
429        n_train : int, default=None
430            Number of training samples (default: n_samples // 2)
431        metric : str, default='energy'
432            Distance metric for optimization ('energy', 'mmd', or 'wasserstein')
433        n_trials : int, default=50
434            Number of Optuna optimization trials
435        **kwargs : dict
436            Additional arguments for Optuna optimization
437
438        Returns:
439        --------
440        self : object
441            Returns self
442        """
443        if Y.ndim == 1:
444            Y = Y.reshape(-1, 1)
445
446        n, d = Y.shape
447        self.n_features_ = d
448        # Determine whether to use RFF
449        if self.use_rff == "auto":
450            self.actual_use_rff_ = n >= self.force_rff_threshold
451        else:
452            self.actual_use_rff_ = self.use_rff
453        # Auto-enable RFF for large datasets with component determination
454        if self.actual_use_rff_:
455            self.actual_rff_components_ = self._determine_components(n)
456            if self.use_rff == "auto":
457                print(
458                    f"Large dataset detected (n={n}). Auto-enabling {self.kernel_approximation.upper()} for scalability."
459                )
460
461        if n_train is None:
462            n_train = n // 2
463        # Store the input distribution function
464        self.X_dist = np.random.normal(0, 1, (n, d))
465        # Create stratified train-test split
466        if self.residual_sampling in ("block-bootstrap", "me-bootstrap"):
467            train_idx, test_idx = self._train_test_split(
468                Y, n_train, sequential=True
469            )
470        else:
471            train_idx, test_idx = self._train_test_split(
472                Y, n_train, sequential=False
473            )
474        Y_train = Y[train_idx]
475        Y_test = Y[test_idx]
476        X_train = self.X_dist[:n_train]
477
478        if self.conformalize:
479
480            def objective(trial):
481                sigma = trial.suggest_float("sigma", 0.01, 10, log=True)
482                lambd = trial.suggest_float("lambd", 1e-5, 1, log=True)
483                gamma = 1 / (2 * sigma**2)
484                # Determine proper training set size (50% of training data)
485                n_proper_train = int(0.5 * len(Y_train))
486                # Use stratified split for proper training and calibration sets
487                proper_train_idx, calib_idx = self._train_test_split(
488                    Y_train, n_proper_train
489                )
490                # Split the data
491                X_proper_train = X_train[proper_train_idx]
492                Y_proper_train = Y_train[proper_train_idx]
493                X_calib = X_train[calib_idx]
494                Y_calib = Y_train[calib_idx]
495                # Standardize the response (Y) using proper training set statistics
496                if not hasattr(self, "y_scaler_"):
497                    self.y_scaler_ = StandardScaler()
498                    Y_proper_train_scaled = self.y_scaler_.fit_transform(
499                        Y_proper_train
500                    )
501                else:
502                    Y_proper_train_scaled = self.y_scaler_.transform(
503                        Y_proper_train
504                    )
505                # Create model with current parameters and fit on standardized proper training set
506                model = self._create_model(gamma, lambd)
507                model.fit(X_proper_train, Y_proper_train_scaled)
508                # Get predictions on calibration set and transform back to original scale
509                preds_calib_scaled = model.predict(X_calib)
510                if preds_calib_scaled.ndim == 1:
511                    preds_calib_scaled = preds_calib_scaled.reshape(-1, 1)
512                # Transform predictions back to original scale
513                preds_calib = self.y_scaler_.inverse_transform(
514                    preds_calib_scaled
515                )
516                # Calculate residuals on calibration set in original scale
517                res_calib = Y_calib - preds_calib
518                # Standardize residuals using calibration set statistics
519                if res_calib.ndim == 1:
520                    res_mean = np.mean(res_calib)
521                    res_std = np.std(res_calib, ddof=1)
522                    # Avoid division by zero
523                    res_std = res_std if res_std > 1e-10 else 1.0
524                    res_calib_standardized = (res_calib - res_mean) / res_std
525                else:
526                    res_mean = np.mean(res_calib, axis=0)
527                    res_std = np.std(res_calib, axis=0, ddof=1)
528                    # Avoid division by zero
529                    res_std = np.where(res_std > 1e-10, res_std, 1.0)
530                    res_calib_standardized = (res_calib - res_mean) / res_std
531
532                # Store the calibrated standardized residuals for use in the generation method
533                calibrated_residuals = (
534                    res_calib_standardized * res_std + res_mean
535                )
536
537                # Use the existing method to generate pseudo samples with conformal prediction
538                Y_sim = self._generate_pseudo_with_model(
539                    model, calibrated_residuals, len(Y_test)
540                )
541
542                # Calculate distance metric
543                if metric == "energy":
544                    dist_val = self._custom_energy_distance(Y_test, Y_sim)
545                elif metric == "mmd":
546                    dist_val = self._mmd(Y_test, Y_sim)
547                elif metric == "wasserstein" and d == 1:
548                    dist_val = stats.wasserstein_distance(
549                        Y_test.flatten(), Y_sim.flatten()
550                    )
551                else:
552                    raise ValueError("Invalid metric for dimension")
553
554                return dist_val
555
556        else:
557
558            def objective(trial):
559                sigma = trial.suggest_float("sigma", 0.01, 10, log=True)
560                lambd = trial.suggest_float("lambd", 1e-5, 1, log=True)
561                gamma = 1 / (2 * sigma**2)
562                # Create model with current parameters
563                model = self._create_model(gamma, lambd)
564                model.fit(X_train, Y_train)
565                preds_train = model.predict(X_train)
566                if preds_train.ndim == 1:
567                    preds_train = preds_train.reshape(-1, 1)
568                res = Y_train - preds_train
569                Y_sim = self._generate_pseudo_with_model(
570                    model, res, len(Y_test)
571                )
572                if metric == "energy":
573                    dist_val = self._custom_energy_distance(Y_test, Y_sim)
574                elif metric == "mmd":
575                    dist_val = self._mmd(Y_test, Y_sim)
576                elif metric == "wasserstein" and d == 1:
577                    dist_val = stats.wasserstein_distance(
578                        Y_test.flatten(), Y_sim.flatten()
579                    )
580                else:
581                    raise ValueError("Invalid metric for dimension")
582                return dist_val
583
584        # Optimize hyperparameters
585        study = optuna.create_study(direction="minimize")
586        study.optimize(objective, n_trials=n_trials, **kwargs)
587        # Store best parameters and fit final model
588        self.best_params_ = study.best_params
589        self.best_score_ = study.best_value
590        sigma = self.best_params_["sigma"]
591        lambd = self.best_params_["lambd"]
592        gamma = 1 / (2 * sigma**2)
593        # Fit final model with best parameters
594        self.model = self._create_model(gamma, lambd)
595        self.model.fit(X_train, Y_train)
596        # Compute residuals
597        preds_train = self.model.predict(X_train)
598        if preds_train.ndim == 1:
599            preds_train = preds_train.reshape(-1, 1)
600        self.residuals_ = Y_train - preds_train
601        # Fit the residual sampler
602        self._fit_residual_sampler()
603        self.is_fitted = True
604        # Print final configuration
605        if self.actual_use_rff_:
606            print(
607                f"  Using {self.kernel_approximation.upper()} with {self.actual_rff_components_} components"
608            )
609        else:
610            print(f"  Using standard kernel method")
611        return self

Fit the data generator to match the distribution of Y.

Parameters:

- Y : array-like, shape (n_samples, n_features) — Target multivariate data to emulate.
- n_train : int, default=None — Number of training samples (default: n_samples // 2).
- metric : str, default='energy' — Distance metric for optimization ('energy', 'mmd', or 'wasserstein'; 'wasserstein' is only valid for one-dimensional Y).
- n_trials : int, default=50 — Number of Optuna optimization trials.
- **kwargs : dict — Additional arguments for Optuna optimization.

Returns:

- self : object — Returns self.

class EmpiricalCopula:
 14class EmpiricalCopula:
 15    """
 16    Empirical Copula implementation for multivariate dependence modeling.
 17
 18    This class implements a non-parametric copula based on the empirical distribution
 19    of the data. It can fit to multivariate data and generate samples that preserve
 20    the original dependence structure.
 21
 22    The empirical copula is defined as:
 23    C_n(u1, ..., ud) = (1/n) * sum(I(U1i <= u1, ..., Udi <= ud))
 24
 25    where U_ji are the pseudo-observations (ranks) of the original data.
 26    """
 27
 28    def __init__(
 29        self,
 30        smoothing_method: str = "none",
 31        jitter_scale: float = 0.01,
 32        boundary_correction: bool = True,
 33    ):
 34        """
 35        Initialize the Empirical Copula.
 36
 37        Parameters:
 38        -----------
 39        smoothing_method : str, default "none"
 40            Smoothing method for the empirical copula:
 41            - "none": Pure empirical copula (no smoothing)
 42            - "jitter": Add small random noise to avoid ties
 43        jitter_scale : float, default 0.0
 44            Scale of uniform jitter to add to pseudo-observations (0 = no jitter).
 45        boundary_correction : bool, default True
 46            Whether to apply boundary correction for kernel methods.
 47        """
 48        self.smoothing_method = smoothing_method
 49        self.jitter_scale = jitter_scale
 50        self.boundary_correction = boundary_correction
 51        # Fitted attributes
 52        self.is_fitted_ = False
 53        self.n_samples_ = None
 54        self.n_vars_ = None
 55        self.pseudo_observations_ = None
 56        self.original_data_ = None
 57        self.marginal_cdfs_ = []
 58        self.marginal_quantiles_ = []
 59        # For kernel-based methods
 60        self.kde_model_ = None
 61        # For Gaussian mixture model
 62        self.gmm_model_ = None
 63
 64    def fit(self, X: np.ndarray) -> "EmpiricalCopula":
 65        """
 66        Fit the empirical copula to the data.
 67
 68        Parameters:
 69        -----------
 70        X : np.ndarray
 71            Input data of shape (n_samples, n_features) on original scale.
 72
 73        Returns:
 74        --------
 75        self : EmpiricalCopula
 76            Returns self for method chaining.
 77
 78        Raises:
 79        -------
 80        ValueError
 81            If X has inappropriate dimensions.
 82        """
 83        X = np.asarray(X)
 84
 85        if X.ndim != 2:
 86            raise ValueError("X must be a 2D array")
 87        if X.shape[1] < 2:
 88            raise ValueError("X must have at least 2 variables")
 89        if X.shape[0] < 2:
 90            raise ValueError("X must have at least 2 observations")
 91
 92        self.n_samples_, self.n_vars_ = X.shape
 93        self.original_data_ = X.copy()
 94        # Step 1: Convert to pseudo-observations (ranks)
 95        self.pseudo_observations_ = self._to_pseudo_observations(X)
 96        # Step 2: Apply smoothing if requested
 97        if self.smoothing_method != "none":
 98            self.pseudo_observations_ = self._apply_smoothing(
 99                self.pseudo_observations_
100            )
101        # Step 3: Store marginal information for inverse transformation
102        self._fit_marginal_transforms(X)
103        self.is_fitted_ = True
104        # print(f"Empirical copula fitted successfully using '{self.smoothing_method}' method")
105        return self
106
107    def sample(
108        self,
109        n_samples: int = 50,
110        method: str = "bootstrap",
111        kernel: str = "gaussian",
112        n_components: int = 5,
113        covariance_type: str = "full",
114        return_pseudo: bool = False,
115        random_state: Optional[int] = 123,
116        **kwargs,
117    ) -> np.ndarray:
118        """
119        Generate samples from the fitted empirical copula.
120
121        Parameters:
122        -----------
123        n_samples : int, default 50
124            Number of samples to generate.
125        method : str, default "bootstrap"
126            Sampling method:
127            - "bootstrap": Bootstrap resampling from fitted pseudo-observations
128            - "kde": Kernel density estimation sampling (if smoothing was used)
129            - "gmm": Gaussian mixture model sampling
130        kernel : str, default "gaussian"
131            Kernel to use if method is "kde" (default is 'gaussian').
132            Can also be 'tophat'.
133        n_components : int, default 5
134            Number of Gaussian components for GMM method.
135        covariance_type : str, default "full"
136            Type of covariance parameters for GMM method.
137            Options: 'full', 'tied', 'diag', 'spherical'.
138        return_pseudo : bool, default False
139            If True, return samples on [0,1] copula scale.
140            If False, return samples transformed to original scale.
141        random_state : int, optional
142            Random state for reproducible sampling.
143        kwargs : additional arguments for specific sampling methods.
144
145        Returns:
146        --------
147        samples : np.ndarray
148            Generated samples of shape (n_samples, n_features).
149        """
150        if not self.is_fitted_:
151            raise ValueError(
152                "Copula must be fitted before sampling. Call fit() first."
153            )
154
155        if random_state is not None:
156            np.random.seed(random_state)
157        # Generate pseudo-observations
158        if method == "bootstrap":
159            pseudo_samples = self._bootstrap_sample(n_samples)
160        elif method == "kde":
161            pseudo_samples = self._kde_sample(
162                n_samples, kernel=kernel, **kwargs
163            )
164        elif method == "gmm":
165            pseudo_samples = self._gmm_sample(
166                n_samples,
167                n_components=n_components,
168                covariance_type=covariance_type,
169                **kwargs,
170            )
171        else:
172            raise ValueError(
173                f"Unknown sampling method: {method}. "
174                f"Supported methods: 'bootstrap', 'kde', 'gmm'"
175            )
176        if return_pseudo:
177            return pseudo_samples
178        # Transform back to original scale
179        return self._inverse_transform(pseudo_samples)
180
181    def plot_pairwise_pseudo(self):
182        if not self.is_fitted_:
183            raise ValueError("Copula must be fitted before plotting.")
184        plt.figure(figsize=(15, 15))
185        for i in range(self.n_vars_):
186            for j in range(i + 1, self.n_vars_):
187                plt.subplot(
188                    self.n_vars_ - 1,
189                    self.n_vars_ - 1,
190                    i * (self.n_vars_ - 1) + j - i,
191                )
192                plt.scatter(
193                    self.pseudo_observations_[:, i],
194                    self.pseudo_observations_[:, j],
195                    s=5,
196                    alpha=0.5,
197                )
198                plt.xlabel(f"Variable {i+1}")
199                plt.ylabel(f"Variable {j+1}")
200                plt.title(
201                    f"Var{i+1}-Var{j+1} (ρ={self._calculate_spearman_matrix(self.original_data_)[i,j]:.2f})"
202                )
203        plt.tight_layout()
204        plt.show()
205
206    def estimate_tail_dependence(self, threshold=0.05):
207        tail_dep = {}
208        for i in range(self.n_vars_):
209            for j in range(i + 1, self.n_vars_):
210                u = self.pseudo_observations_[:, i]
211                v = self.pseudo_observations_[:, j]
212                lower_tail = (
213                    np.mean((u < threshold) & (v < threshold)) / threshold
214                )
215                upper_tail = (
216                    np.mean((u > 1 - threshold) & (v > 1 - threshold))
217                    / threshold
218                )
219                tail_dep[f"var{i+1}-var{j+1}"] = {
220                    "lower": lower_tail,
221                    "upper": upper_tail,
222                }
223        return tail_dep
224
225    def plot_marginals(self, simulated_samples):
226        import matplotlib.pyplot as plt
227
228        plt.figure(figsize=(15, 5))
229        for j in range(self.n_vars_):
230            plt.subplot(2, self.n_vars_ // 2, j + 1)
231            plt.hist(
232                self.original_data_[:, j],
233                bins=30,
234                alpha=0.5,
235                label="Original",
236                density=True,
237            )
238            plt.hist(
239                simulated_samples[:, j],
240                bins=30,
241                alpha=0.5,
242                label="Simulated",
243                density=True,
244            )
245            plt.title(f"Variable {j+1}")
246            plt.legend()
247        plt.tight_layout()
248        plt.show()
249
250    def validate_fit(
251        self,
252        X_test: Optional[np.ndarray] = None,
253        n_bootstrap: int = 250,
254        alpha: float = 0.05,
255        verbose: bool = True,
256    ) -> Dict:
257        """
258        Validate the fitted empirical copula using comprehensive hypothesis tests.
259
260        This method performs:
261        1. Kolmogorov-Smirnov tests on marginal distributions
262        2. Anderson-Darling tests for marginal goodness-of-fit
263        3. Tests for dependence measures (Spearman rho, Kendall tau, Pearson correlation)
264        4. Cramér-von Mises test for copula goodness-of-fit
265        5. Tests for uniform distribution of pseudo-observations
266
267        Parameters:
268        -----------
269        X_test : np.ndarray, optional
270            Test data for validation. If None, uses training data.
271        n_bootstrap : int, default 1000
272            Number of bootstrap samples for validation.
273        alpha : float, default 0.05
274            Significance level for statistical tests.
275        verbose : bool, default True
276            Whether to print detailed validation results.
277
278        Returns:
279        --------
280        validation_results : dict
281            Dictionary containing all validation test results.
282        """
283        if not self.is_fitted_:
284            raise ValueError("Copula must be fitted before validation.")
285
286        # Use training data if no test data provided
287        if X_test is None:
288            X_test = self.original_data_.copy()
289            if verbose:
290                print("Note: Using training data for validation")
291
292        # Generate bootstrap samples for comparison
293        bootstrap_samples = self.sample(n_samples=n_bootstrap, random_state=42)
294
295        results = {
296            "marginal_tests": {},
297            "dependence_tests": {},
298            "copula_tests": {},
299            "uniformity_tests": {},
300            "summary": {},
301        }
302
303        if verbose:
304            print("\n=== EMPIRICAL COPULA VALIDATION TESTS ===\n")
305
306        # 1. MARGINAL DISTRIBUTION TESTS
307        if verbose:
308            print("1. Marginal Distribution Tests:")
309            print("-" * 35)
310
311        for j in range(self.n_vars_):
312            original_margin = X_test[:, j]
313            simulated_margin = bootstrap_samples[:, j]
314            # Kolmogorov-Smirnov test
315            ks_stat, ks_pvalue = stats.ks_2samp(
316                original_margin, simulated_margin
317            )
318            # Anderson-Darling test (if samples are from same distribution)
319            try:
320                # Combine samples and test if they're from the same distribution
321                combined_data = np.concatenate(
322                    [original_margin, simulated_margin]
323                )
324                combined_mean = np.mean(combined_data)
325                combined_std = np.std(combined_data)
326                # Test both against normal distribution with combined parameters
327                ad_orig = stats.anderson(original_margin, dist="norm")
328                ad_sim = stats.anderson(simulated_margin, dist="norm")
329                # Use the test statistic difference as a measure
330                ad_diff = abs(ad_orig.statistic - ad_sim.statistic)
331                ad_critical = ad_orig.critical_values[
332                    2
333                ]  # 5% significance level
334                ad_pass = ad_diff < ad_critical * 0.5  # Heuristic threshold
335            except:
336                ad_diff = np.nan
337                ad_pass = None
338            # Two-sample t-test for means
339            ttest_stat, ttest_pvalue = stats.ttest_ind(
340                original_margin, simulated_margin
341            )
342            # Levene's test for equal variances
343            levene_stat, levene_pvalue = stats.levene(
344                original_margin, simulated_margin
345            )
346
347            results["marginal_tests"][f"variable_{j+1}"] = {
348                "ks_statistic": ks_stat,
349                "ks_p_value": ks_pvalue,
350                "ks_reject_null": ks_pvalue < alpha,
351                "ad_difference": ad_diff,
352                "ad_pass": ad_pass,
353                "ttest_statistic": ttest_stat,
354                "ttest_p_value": ttest_pvalue,
355                "mean_difference_significant": ttest_pvalue < alpha,
356                "levene_statistic": levene_stat,
357                "levene_p_value": levene_pvalue,
358                "variance_difference_significant": levene_pvalue < alpha,
359            }
360
361            if verbose:
362                status = "FAIL" if ks_pvalue < alpha else "PASS"
363                mean_status = "FAIL" if ttest_pvalue < alpha else "PASS"
364                var_status = "FAIL" if levene_pvalue < alpha else "PASS"
365                print(f"Variable {j+1}:")
366                print(
367                    f"  KS test: statistic={ks_stat:.4f}, p-value={ks_pvalue:.4f} [{status}]"
368                )
369                print(
370                    f"  Mean test: p-value={ttest_pvalue:.4f} [{mean_status}]"
371                )
372                print(
373                    f"  Variance test: p-value={levene_pvalue:.4f} [{var_status}]"
374                )
375
376        # 2. DEPENDENCE STRUCTURE TESTS
377        if verbose:
378            print(f"\n2. Dependence Structure Tests:")
379            print("-" * 32)
380
381        # Calculate dependence measures
382        orig_corr = np.corrcoef(X_test.T)
383        orig_spearman = self._calculate_spearman_matrix(X_test)
384        orig_kendall = self._calculate_kendall_matrix(X_test)
385
386        sim_corr = np.corrcoef(bootstrap_samples.T)
387        sim_spearman = self._calculate_spearman_matrix(bootstrap_samples)
388        sim_kendall = self._calculate_kendall_matrix(bootstrap_samples)
389
390        # Statistical tests for dependence measures
391        dependence_results = {}
392
393        for i in range(self.n_vars_):
394            for j in range(i + 1, self.n_vars_):
395                pair_name = f"var{i+1}_var{j+1}"
396                # Test correlations using Fisher's z-transform
397                r1, r2 = orig_corr[i, j], sim_corr[i, j]
398                n1, n2 = len(X_test), len(bootstrap_samples)
399                # Fisher's z-transform
400                z1 = (
401                    0.5 * np.log((1 + r1) / (1 - r1))
402                    if abs(r1) < 0.999
403                    else np.sign(r1) * 3
404                )
405                z2 = (
406                    0.5 * np.log((1 + r2) / (1 - r2))
407                    if abs(r2) < 0.999
408                    else np.sign(r2) * 3
409                )
410                # Test statistic
411                se = np.sqrt(1 / (n1 - 3) + 1 / (n2 - 3))
412                z_stat = (z1 - z2) / se if se > 0 else 0
413                corr_pvalue = 2 * (1 - stats.norm.cdf(abs(z_stat)))
414                # Spearman and Kendall differences
415                spear_diff = abs(orig_spearman[i, j] - sim_spearman[i, j])
416                kendall_diff = abs(orig_kendall[i, j] - sim_kendall[i, j])
417
418                dependence_results[pair_name] = {
419                    "pearson_original": r1,
420                    "pearson_simulated": r2,
421                    "pearson_z_statistic": z_stat,
422                    "pearson_p_value": corr_pvalue,
423                    "pearson_significant_diff": corr_pvalue < alpha,
424                    "spearman_difference": spear_diff,
425                    "kendall_difference": kendall_diff,
426                    "spearman_large_diff": spear_diff > 0.1,
427                    "kendall_large_diff": kendall_diff > 0.1,
428                }
429
430        results["dependence_tests"] = dependence_results
431
432        if verbose:
433            for pair, tests in dependence_results.items():
434                corr_status = (
435                    "FAIL" if tests["pearson_significant_diff"] else "PASS"
436                )
437                spear_status = (
438                    "WARN" if tests["spearman_large_diff"] else "PASS"
439                )
440                kendall_status = (
441                    "WARN" if tests["kendall_large_diff"] else "PASS"
442                )
443                print(f"{pair.replace('_', '-')}:")
444                print(
445                    f"  Pearson: {tests['pearson_original']:.4f} vs {tests['pearson_simulated']:.4f}, "
446                    f"p-val={tests['pearson_p_value']:.4f} [{corr_status}]"
447                )
448                print(
449                    f"  Spearman diff: {tests['spearman_difference']:.4f} [{spear_status}]"
450                )
451                print(
452                    f"  Kendall diff: {tests['kendall_difference']:.4f} [{kendall_status}]"
453                )
454        # 3. UNIFORMITY TESTS FOR PSEUDO-OBSERVATIONS
455        if verbose:
456            print(f"\n3. Uniformity Tests (Pseudo-Observations):")
457            print("-" * 42)
458        # Test if pseudo-observations are uniform on [0,1]
459        pseudo_test = self._to_pseudo_observations(X_test)
460
461        uniformity_results = {}
462
463        for j in range(self.n_vars_):
464            pseudo_margin = pseudo_test[:, j]
465            # Kolmogorov-Smirnov test against uniform distribution
466            ks_uniform_stat, ks_uniform_pvalue = stats.kstest(
467                pseudo_margin, "uniform"
468            )
469            # Anderson-Darling test for uniformity
470            # Transform to standard normal and test
471            normal_transformed = stats.norm.ppf(
472                np.clip(pseudo_margin, 1e-10, 1 - 1e-10)
473            )
474            ad_result = stats.anderson(normal_transformed, dist="norm")
475
476            # Cramer-von Mises test for uniformity
477            def cvm_uniform(data):
478                """Cramér-von Mises test for uniform distribution."""
479                n = len(data)
480                sorted_data = np.sort(data)
481                i = np.arange(1, n + 1)
482                T = (1.0 / (12 * n)) + np.sum(
483                    ((2 * i - 1) / (2 * n) - sorted_data) ** 2
484                )
485                return T
486
487            cvm_stat = cvm_uniform(pseudo_margin)
488            # Critical value at 5% significance level
489            cvm_critical = 0.461 / (
490                np.sqrt(len(pseudo_margin))
491                + 0.25
492                + 0.75 / np.sqrt(len(pseudo_margin))
493            )
494
495            uniformity_results[f"variable_{j+1}"] = {
496                "ks_uniform_statistic": ks_uniform_stat,
497                "ks_uniform_p_value": ks_uniform_pvalue,
498                "ks_uniform_reject": ks_uniform_pvalue < alpha,
499                "ad_statistic": ad_result.statistic,
500                "ad_critical_5pct": ad_result.critical_values[2],
501                "ad_reject": ad_result.statistic > ad_result.critical_values[2],
502                "cvm_statistic": cvm_stat,
503                "cvm_critical": cvm_critical,
504                "cvm_reject": cvm_stat > cvm_critical,
505            }
506
507            if verbose:
508                ks_status = "FAIL" if ks_uniform_pvalue < alpha else "PASS"
509                ad_status = (
510                    "FAIL"
511                    if ad_result.statistic > ad_result.critical_values[2]
512                    else "PASS"
513                )
514                cvm_status = "FAIL" if cvm_stat > cvm_critical else "PASS"
515                print(f"Variable {j+1}:")
516                print(
517                    f"  KS uniform: stat={ks_uniform_stat:.4f}, p-val={ks_uniform_pvalue:.4f} [{ks_status}]"
518                )
519                print(
520                    f"  AD normal: stat={ad_result.statistic:.4f}, crit={ad_result.critical_values[2]:.4f} [{ad_status}]"
521                )
522                print(
523                    f"  CvM uniform: stat={cvm_stat:.4f}, crit={cvm_critical:.4f} [{cvm_status}]"
524                )
525
526        results["uniformity_tests"] = uniformity_results
527
528        pseudo_orig = self._to_pseudo_observations(X_test)
529        pseudo_sim = self._to_pseudo_observations(bootstrap_samples)
530        # 5. SUMMARY ASSESSMENT
531        if verbose:
532            print(f"\n5. Overall Assessment:")
533            print("-" * 22)
534        # Count various test failures
535        ks_failures = sum(
536            1
537            for j in range(self.n_vars_)
538            if results["marginal_tests"][f"variable_{j+1}"]["ks_reject_null"]
539        )
540
541        mean_failures = sum(
542            1
543            for j in range(self.n_vars_)
544            if results["marginal_tests"][f"variable_{j+1}"][
545                "mean_difference_significant"
546            ]
547        )
548
549        var_failures = sum(
550            1
551            for j in range(self.n_vars_)
552            if results["marginal_tests"][f"variable_{j+1}"][
553                "variance_difference_significant"
554            ]
555        )
556
557        corr_failures = sum(
558            1
559            for tests in dependence_results.values()
560            if tests["pearson_significant_diff"]
561        )
562
563        uniform_failures = sum(
564            1
565            for j in range(self.n_vars_)
566            if results["uniformity_tests"][f"variable_{j+1}"][
567                "ks_uniform_reject"
568            ]
569        )
570        # Calculate average differences
571        avg_spear_diff = np.mean(
572            [
573                tests["spearman_difference"]
574                for tests in dependence_results.values()
575            ]
576        )
577        avg_kendall_diff = np.mean(
578            [
579                tests["kendall_difference"]
580                for tests in dependence_results.values()
581            ]
582        )
583        # Overall quality assessment
584        total_tests = (
585            self.n_vars_ * 3 + len(dependence_results) + self.n_vars_ + 1
586        )
587        total_failures = (
588            ks_failures
589            + mean_failures
590            + var_failures
591            + corr_failures
592            + uniform_failures
593        )
594
595        pass_rate = (total_tests - total_failures) / total_tests * 100
596
597        if pass_rate >= 85 and avg_spear_diff <= 0.05:
598            quality = "Excellent"
599        elif pass_rate >= 70 and avg_spear_diff <= 0.10:
600            quality = "Good"
601        elif pass_rate >= 50 and avg_spear_diff <= 0.15:
602            quality = "Fair"
603        else:
604            quality = "Poor"
605
606        results["summary"] = {
607            "ks_failures": ks_failures,
608            "mean_failures": mean_failures,
609            "variance_failures": var_failures,
610            "correlation_failures": corr_failures,
611            "uniformity_failures": uniform_failures,
612            "total_failures": total_failures,
613            "total_tests": total_tests,
614            "pass_rate": pass_rate,
615            "avg_spearman_difference": avg_spear_diff,
616            "avg_kendall_difference": avg_kendall_diff,
617            "overall_quality": quality,
618        }
619
620        if verbose:
621            print(f"Test Summary ({total_tests} total tests):")
622            print(f"  Marginal KS failures: {ks_failures}/{self.n_vars_}")
623            print(f"  Mean difference failures: {mean_failures}/{self.n_vars_}")
624            print(
625                f"  Variance difference failures: {var_failures}/{self.n_vars_}"
626            )
627            print(
628                f"  Correlation failures: {corr_failures}/{len(dependence_results)}"
629            )
630            print(f"  Uniformity failures: {uniform_failures}/{self.n_vars_}")
631            print(f"  Overall pass rate: {pass_rate:.1f}%")
632            print(f"  Average Spearman difference: {avg_spear_diff:.4f}")
633            print(f"  Average Kendall difference: {avg_kendall_diff:.4f}")
634            print(f"  Overall model quality: {quality}")
635
636        return results
637
638    def _to_pseudo_observations(self, X: np.ndarray) -> np.ndarray:
639        """Convert data to pseudo-observations using empirical CDF."""
640        n_samples, n_vars = X.shape
641        pseudo_obs = np.zeros_like(X)
642
643        for j in range(n_vars):
644            # Rank-based transformation
645            ranks = stats.rankdata(X[:, j], method="average")
646            # Use (rank - 0.5) / n to avoid boundary values
647            pseudo_obs[:, j] = (ranks - 0.5) / n_samples
648
649        return pseudo_obs
650
651    def _apply_smoothing(self, pseudo_obs: np.ndarray) -> np.ndarray:
652        """Apply smoothing to pseudo-observations."""
653        if self.smoothing_method == "jitter":
654            # Add uniform jitter
655            jitter = np.random.uniform(
656                -self.jitter_scale / 2, self.jitter_scale / 2, pseudo_obs.shape
657            )
658            smoothed = pseudo_obs + jitter
659            # Ensure values stay in [0,1]
660            smoothed = np.clip(smoothed, 1e-10, 1 - 1e-10)
661            return smoothed
662        else:
663            return pseudo_obs
664
665    def _fit_marginal_transforms(self, X: np.ndarray) -> None:
666        """Fit marginal transformations for inverse sampling."""
667        self.marginal_cdfs_ = []
668        self.marginal_quantiles_ = []
669
670        for j in range(self.n_vars_):
671            data_col = X[:, j]
672            sorted_data = np.sort(data_col)
673            # Create empirical CDF
674            n = len(sorted_data)
675            cdf_values = np.arange(1, n + 1) / n
676            # Store quantile function (inverse CDF)
677            # Add boundary extrapolation
678            extended_probs = np.concatenate([[0], cdf_values, [1]])
679            extended_data = np.concatenate(
680                [
681                    [sorted_data[0] - (sorted_data[1] - sorted_data[0])],
682                    sorted_data,
683                    [sorted_data[-1] + (sorted_data[-1] - sorted_data[-2])],
684                ]
685            )
686            quantile_func = interp1d(
687                extended_probs,
688                extended_data,
689                kind="linear",
690                bounds_error=False,
691                fill_value=(extended_data[0], extended_data[-1]),
692            )
693            self.marginal_quantiles_.append(quantile_func)
694
695    def _bootstrap_sample(self, n_samples: int) -> np.ndarray:
696        """Generate samples using bootstrap resampling."""
697        # Randomly sample indices with replacement
698        indices = np.random.choice(
699            self.n_samples_, size=n_samples, replace=True
700        )
701        return self.pseudo_observations_[indices]
702
703    def _kde_sample(
704        self, n_samples: int, kernel="gaussian", **kwargs
705    ) -> np.ndarray:
706        """Generate samples using kernel density estimation."""
707        kernel_bandwidths = {"bandwidth": np.logspace(-6, 6, 150)}
708        grid = GridSearchCV(
709            KernelDensity(kernel=kernel, **kwargs), param_grid=kernel_bandwidths
710        )
711        grid.fit(self.pseudo_observations_)
712        self.kde_model_ = grid.best_estimator_
713        return self.kde_model_.sample(n_samples)
714
715    def _gmm_sample(
716        self,
717        n_samples: int,
718        n_components: int = 5,
719        covariance_type: str = "full",
720        **kwargs,
721    ) -> np.ndarray:
722        """
723        Generate samples using Gaussian mixture model.
724
725        Parameters:
726        -----------
727        n_samples : int
728            Number of samples to generate.
729        n_components : int, default 5
730            Number of Gaussian components in the mixture.
731        covariance_type : str, default "full"
732            Type of covariance parameters. Options: 'full', 'tied', 'diag', 'spherical'.
733        **kwargs : additional arguments for GaussianMixture.
734
735        Returns:
736        --------
737        samples : np.ndarray
738            Generated samples on [0,1] copula scale.
739        """
740        # Fit Gaussian mixture model to pseudo-observations
741        gmm = GaussianMixture(
742            n_components=n_components,
743            covariance_type=covariance_type,
744            random_state=kwargs.get("random_state", None),
745            **{k: v for k, v in kwargs.items() if k != "random_state"},
746        )
747
748        # Fit the model
749        gmm.fit(self.pseudo_observations_)
750
751        # Store the fitted model
752        self.gmm_model_ = gmm
753
754        # Generate samples
755        samples, _ = gmm.sample(n_samples)
756
757        # Ensure samples are in [0,1] range (clip if necessary)
758        samples = np.clip(samples, 1e-10, 1 - 1e-10)
759
760        return samples
761
762    def _inverse_transform(self, pseudo_samples: np.ndarray) -> np.ndarray:
763        """Transform pseudo-observations back to original scale."""
764        n_samples, n_vars = pseudo_samples.shape
765        original_samples = np.zeros_like(pseudo_samples)
766
767        for j in range(n_vars):
768            u = pseudo_samples[:, j]
769            # Ensure values are in valid range
770            u = np.clip(u, 1e-10, 1 - 1e-10)
771            original_samples[:, j] = self.marginal_quantiles_[j](u)
772
773        return original_samples
774
775    def _calculate_spearman_matrix(self, X: np.ndarray) -> np.ndarray:
776        """Calculate Spearman rank correlation matrix."""
777        n_vars = X.shape[1]
778        spearman_matrix = np.zeros((n_vars, n_vars))
779
780        for i in range(n_vars):
781            for j in range(n_vars):
782                if i == j:
783                    spearman_matrix[i, j] = 1.0
784                else:
785                    spearman_matrix[i, j], _ = stats.spearmanr(X[:, i], X[:, j])
786
787        return spearman_matrix
788
789    def _calculate_kendall_matrix(self, X: np.ndarray) -> np.ndarray:
790        """Calculate Kendall's tau correlation matrix."""
791        n_vars = X.shape[1]
792        kendall_matrix = np.zeros((n_vars, n_vars))
793
794        for i in range(n_vars):
795            for j in range(n_vars):
796                if i == j:
797                    kendall_matrix[i, j] = 1.0
798                else:
799                    kendall_matrix[i, j], _ = stats.kendalltau(X[:, i], X[:, j])
800
801        return kendall_matrix
802
803    def get_info(self) -> Dict:
804        """Get information about the fitted empirical copula."""
805        if not self.is_fitted_:
806            raise ValueError("Copula must be fitted first.")
807
808        return {
809            "n_samples": self.n_samples_,
810            "n_vars": self.n_vars_,
811            "smoothing_method": self.smoothing_method,
812            "jitter_scale": self.jitter_scale,
813            "boundary_correction": self.boundary_correction,
814            "has_kde_model": self.kde_model_ is not None,
815            "has_gmm_model": self.gmm_model_ is not None,
816        }
817
818    def __repr__(self) -> str:
819        if self.is_fitted_:
820            return (
821                f"EmpiricalCopula(n_samples={self.n_samples_}, n_vars={self.n_vars_}, "
822                f"smoothing='{self.smoothing_method}', fitted=True)"
823            )
824        else:
825            return f"EmpiricalCopula(smoothing='{self.smoothing_method}', fitted=False)"

Empirical Copula implementation for multivariate dependence modeling.

This class implements a non-parametric copula based on the empirical distribution of the data. It can fit to multivariate data and generate samples that preserve the original dependence structure.

The empirical copula is defined as: C_n(u_1, ..., u_d) = (1/n) * Σ_{i=1}^{n} I(U_{1i} ≤ u_1, ..., U_{di} ≤ u_d)

where I(·) is the indicator function and U_{ji} is the pseudo-observation (rescaled rank) of the i-th observation of variable j.

def fit(self, X: numpy.ndarray) -> EmpiricalCopula:
 64    def fit(self, X: np.ndarray) -> "EmpiricalCopula":
 65        """
 66        Fit the empirical copula to the data.
 67
 68        Parameters:
 69        -----------
 70        X : np.ndarray
 71            Input data of shape (n_samples, n_features) on original scale.
 72
 73        Returns:
 74        --------
 75        self : EmpiricalCopula
 76            Returns self for method chaining.
 77
 78        Raises:
 79        -------
 80        ValueError
 81            If X has inappropriate dimensions.
 82        """
 83        X = np.asarray(X)
 84
 85        if X.ndim != 2:
 86            raise ValueError("X must be a 2D array")
 87        if X.shape[1] < 2:
 88            raise ValueError("X must have at least 2 variables")
 89        if X.shape[0] < 2:
 90            raise ValueError("X must have at least 2 observations")
 91
 92        self.n_samples_, self.n_vars_ = X.shape
 93        self.original_data_ = X.copy()
 94        # Step 1: Convert to pseudo-observations (ranks)
 95        self.pseudo_observations_ = self._to_pseudo_observations(X)
 96        # Step 2: Apply smoothing if requested
 97        if self.smoothing_method != "none":
 98            self.pseudo_observations_ = self._apply_smoothing(
 99                self.pseudo_observations_
100            )
101        # Step 3: Store marginal information for inverse transformation
102        self._fit_marginal_transforms(X)
103        self.is_fitted_ = True
104        # print(f"Empirical copula fitted successfully using '{self.smoothing_method}' method")
105        return self

Fit the empirical copula to the data.

Parameters:

X : np.ndarray Input data of shape (n_samples, n_features) on original scale.

Returns:

self : EmpiricalCopula Returns self for method chaining.

Raises:

ValueError If X is not 2-dimensional, has fewer than 2 variables (columns), or has fewer than 2 observations (rows).

class StratifiedClusteringSubsampling:
 14class StratifiedClusteringSubsampling:
 15    def __init__(
 16        self,
 17        n_components=3,
 18        method=ClusterMethod.GMM,
 19        random_state=None,
 20        **kwargs,
 21    ):
 22        """
 23        Initializes the StratifiedClusteringSubsampling class.
 24
 25        :param n_components: Number of clusters for clustering algorithm. Default is 3.
 26        :param method: Cluster method - 'gmm' or 'kmeans'. Default is GMM.
 27        :param random_state: Seed for random number generator.
 28        :param kwargs: Additional parameters for the clustering algorithms.
 29        """
 30        self.n_components = n_components
 31        self.method = (
 32            method
 33            if isinstance(method, ClusterMethod)
 34            else ClusterMethod(method.lower())
 35        )
 36        self.random_state = random_state
 37        self.kwargs = kwargs
 38
 39        # Initialize the clustering model based on the chosen method
 40        if self.method == ClusterMethod.GMM:
 41            self.cluster_model = GaussianMixture(
 42                n_components=self.n_components,
 43                random_state=self.random_state,
 44                **kwargs,
 45            )
 46        elif self.method == ClusterMethod.KMEANS:
 47            self.cluster_model = KMeans(
 48                n_clusters=self.n_components,
 49                random_state=self.random_state,
 50                **kwargs,
 51            )
 52        else:
 53            raise ValueError(
 54                f"Unsupported method: {method}. Choose 'gmm' or 'kmeans'."
 55            )
 56
 57    def fit(self, data):
 58        """
 59        Fit the clustering model to the given 2D data.
 60
 61        :param data: 2D numpy array where each row is a data point and each column is a feature.
 62        """
 63        # Input validation
 64        if not isinstance(data, np.ndarray):
 65            raise TypeError("Data must be a numpy array")
 66        if data.ndim != 2:
 67            raise ValueError("Data must be 2-dimensional")
 68        if len(data) < self.n_components:
 69            raise ValueError(
 70                f"Number of samples ({len(data)}) must be >= n_components ({self.n_components})"
 71            )
 72
 73        self.cluster_model.fit(data)
 74
 75        # Get cluster labels based on the method
 76        if self.method == ClusterMethod.GMM:
 77            self.cluster_labels = self.cluster_model.predict(data)
 78        else:  # KMEANS
 79            self.cluster_labels = self.cluster_model.labels_
 80
 81        return self
 82
 83    def stratified_sample(self, data, test_size=0.3):
 84        """
 85        Perform stratified sampling based on cluster labels.
 86
 87        :param data: 2D numpy array to sample from.
 88        :param test_size: Proportion of data to be used for testing (default is 30%).
 89        :return: Tuple of (train_data, test_data) where each is a stratified sample.
 90        """
 91        if not hasattr(self, "cluster_labels"):
 92            raise ValueError("Must call fit() before stratified_sample()")
 93        if len(data) != len(self.cluster_labels):
 94            raise ValueError(
 95                "Data length must match fitted cluster labels length"
 96            )
 97        if not 0 < test_size < 1:
 98            raise ValueError("test_size must be between 0 and 1")
 99
100        sss = StratifiedShuffleSplit(
101            n_splits=1, test_size=test_size, random_state=self.random_state
102        )
103
104        for train_index, test_index in sss.split(data, self.cluster_labels):
105            train_data = data[train_index]
106            test_data = data[test_index]
107
108        return train_data, test_data
109
110    def get_cluster_labels(self):
111        """
112        Get the cluster labels assigned by the clustering model.
113
114        :return: 1D numpy array of cluster labels for each data point.
115        """
116        if not hasattr(self, "cluster_labels"):
117            raise ValueError("Must call fit() first")
118        return self.cluster_labels
119
120    def predict(self, data):
121        """
122        Predict the cluster labels for new data using the fitted model.
123
124        :param data: 2D numpy array of new data points.
125        :return: Cluster labels for the new data points.
126        """
127        if self.method == ClusterMethod.GMM:
128            return self.cluster_model.predict(data)
129        else:  # KMEANS
130            return self.cluster_model.predict(data)
131
132    def get_cluster_centers(self):
133        """Get the cluster centers."""
134        if self.method == ClusterMethod.GMM:
135            if not hasattr(self.cluster_model, "means_"):
136                raise ValueError("GMM not fitted yet")
137            return self.cluster_model.means_
138        else:  # KMEANS
139            if not hasattr(self.cluster_model, "cluster_centers_"):
140                raise ValueError("KMeans not fitted yet")
141            return self.cluster_model.cluster_centers_
142
143    def get_cluster_proportions(self):
144        """Get the proportion of data points in each cluster."""
145        if not hasattr(self, "cluster_labels"):
146            raise ValueError("Must call fit() first")
147        unique, counts = np.unique(self.cluster_labels, return_counts=True)
148        return counts / len(self.cluster_labels)
149
150    def score_samples(self, data):
151        """
152        Get the model scores for samples.
153        For GMM: log-likelihood of samples
154        For KMeans: negative of inertia (distance to closest cluster center)
155        """
156        if self.method == ClusterMethod.GMM:
157            return self.cluster_model.score_samples(data)
158        else:
159            # For KMeans, return negative distances to cluster centers
160            return -self.cluster_model.transform(data).min(axis=1)
161
162    def get_model_params(self):
163        """Get the parameters of the fitted model."""
164        return self.cluster_model.get_params()
165
166    def set_method(self, method):
167        """
168        Change the clustering method after initialization (will require re-fitting).
169
170        :param method: New cluster method ('gmm' or 'kmeans')
171        """
172        old_method = self.method
173        self.method = (
174            method
175            if isinstance(method, ClusterMethod)
176            else ClusterMethod(method.lower())
177        )
178
179        if self.method != old_method:
180            # Reinitialize the model with the new method
181            if self.method == ClusterMethod.GMM:
182                self.cluster_model = GaussianMixture(
183                    n_components=self.n_components,
184                    random_state=self.random_state,
185                    **self.kwargs,
186                )
187            else:  # KMEANS
188                self.cluster_model = KMeans(
189                    n_clusters=self.n_components,
190                    random_state=self.random_state,
191                    **self.kwargs,
192                )
193
194            # Remove fitted attributes to force re-fitting
195            if hasattr(self, "cluster_labels"):
196                del self.cluster_labels
def fit(self, data):
57    def fit(self, data):
58        """
59        Fit the clustering model to the given 2D data.
60
61        :param data: 2D numpy array where each row is a data point and each column is a feature.
62        """
63        # Input validation
64        if not isinstance(data, np.ndarray):
65            raise TypeError("Data must be a numpy array")
66        if data.ndim != 2:
67            raise ValueError("Data must be 2-dimensional")
68        if len(data) < self.n_components:
69            raise ValueError(
70                f"Number of samples ({len(data)}) must be >= n_components ({self.n_components})"
71            )
72
73        self.cluster_model.fit(data)
74
75        # Get cluster labels based on the method
76        if self.method == ClusterMethod.GMM:
77            self.cluster_labels = self.cluster_model.predict(data)
78        else:  # KMEANS
79            self.cluster_labels = self.cluster_model.labels_
80
81        return self

Fit the clustering model to the given 2D data.

Parameters
  • data: 2D numpy array where each row is a data point and each column is a feature.
def predict(self, data):
120    def predict(self, data):
121        """
122        Predict the cluster labels for new data using the fitted model.
123
124        :param data: 2D numpy array of new data points.
125        :return: Cluster labels for the new data points.
126        """
127        if self.method == ClusterMethod.GMM:
128            return self.cluster_model.predict(data)
129        else:  # KMEANS
130            return self.cluster_model.predict(data)

Predict the cluster labels for new data using the fitted model.

Parameters
  • data: 2D numpy array of new data points.
Returns

Cluster labels for the new data points.

class SubSampler:
 6class SubSampler:
 7    """Subsampling class.
 8
 9    Attributes:
10
11       y: array-like, shape = [n_samples]
12           Target values.
13
14       row_sample: double
15           subsampling fraction
16
17       n_samples: int
18            subsampling by using the number of rows (supersedes row_sample)
19
20       seed: int
21           reproductibility seed
22
23       n_jobs: int
24            number of jobs to run in parallel
25
26       verbose: bool
27            print progress messages and bars
28    """
29
30    def __init__(
31        self,
32        y,
33        row_sample=0.8,
34        n_samples=None,
35        seed=123,
36        n_jobs=None,
37        verbose=False,
38    ):
39        self.y = y
40        self.n_samples = n_samples
41        if self.n_samples is None:
42            assert (
43                row_sample < 1 and row_sample >= 0
44            ), "'row_sample' must be provided, plus < 1 and >= 0"
45            self.row_sample = row_sample
46        else:
47            assert self.n_samples < len(y), "'n_samples' must be < len(y)"
48            self.row_sample = self.n_samples / len(y)
49        self.seed = seed
50        self.indices = None
51        self.n_jobs = n_jobs
52        self.verbose = verbose
53
54    def subsample(self):
55        """Returns indices of subsampled input data.
56
57        Examples:
58
59        <ul>
60            <li> <a href="https://github.com/Techtonique/nnetsauce/blob/master/nnetsauce/demo/thierrymoudiki_20240105_subsampling.ipynb">20240105_subsampling.ipynb</a> </li>
61            <li> <a href="https://github.com/Techtonique/nnetsauce/blob/master/nnetsauce/demo/thierrymoudiki_20240131_subsampling_nsamples.ipynb">20240131_subsampling_nsamples.ipynb</a> </li>
62        </ul>
63
64        """
65        self.indices = dosubsample(
66            y=self.y,
67            row_sample=self.row_sample,
68            seed=self.seed,
69            n_jobs=self.n_jobs,
70            verbose=self.verbose,
71        )
72        return self.indices

Subsampling class.

Attributes:

y: array-like, shape = [n_samples] Target values.

row_sample: float subsampling fraction, in [0, 1); ignored when n_samples is given

n_samples: int number of rows to subsample (supersedes row_sample, which is then set to n_samples / len(y))

seed: int reproducibility seed

n_jobs: int number of jobs to run in parallel

verbose: bool print progress messages and bars

def subsample(self):
54    def subsample(self):
55        """Returns indices of subsampled input data.
56
57        Examples:
58
59        <ul>
60            <li> <a href="https://github.com/Techtonique/nnetsauce/blob/master/nnetsauce/demo/thierrymoudiki_20240105_subsampling.ipynb">20240105_subsampling.ipynb</a> </li>
61            <li> <a href="https://github.com/Techtonique/nnetsauce/blob/master/nnetsauce/demo/thierrymoudiki_20240131_subsampling_nsamples.ipynb">20240131_subsampling_nsamples.ipynb</a> </li>
62        </ul>
63
64        """
65        self.indices = dosubsample(
66            y=self.y,
67            row_sample=self.row_sample,
68            seed=self.seed,
69            n_jobs=self.n_jobs,
70            verbose=self.verbose,
71        )
72        return self.indices

Returns indices of subsampled input data.

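Examples:

See the nnetsauce demo notebooks `20240105_subsampling.ipynb` and `20240131_subsampling_nsamples.ipynb` (Techtonique/nnetsauce repository, `nnetsauce/demo/`).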

class SmartHealthSimulator:
 12class SmartHealthSimulator:
 13    """
 14    Simulates a synthetic, multimodal time series dataset resembling wearable, environmental,
 15    behavioral, and self-reported health data over time. Includes numeric, categorical, and text data.
 16
 17    The simulator generates realistic daily records including:
 18    - Heart rate
 19    - Steps
 20    - Skin and ambient temperature
 21    - Activity label (rest, walk, exercise)
 22    - Mood score (1-5)
 23    - Mood notes (short texts)
 24    - Air quality index
 25    - Sleep quality (dependent variable)
 26
 27    Includes methods for plotting time series, distributions, relationships, and text-based visualizations.
 28
 29    Examples
 30    --------
 31    >>> sim = SmartHealthSimulator(days=180, seed=42)
 32    >>> print(sim.data.head())
 33    >>> sim.plot_time_series()
 34    >>> sim.plot_mood_sleep()
 35    >>> sim.plot_activity_distribution()
 36    >>> sim.plot_mood_wordcloud()
 37    """
 38
 39    def __init__(self, days: int = 180, seed: int = 123):
 40        """
 41        Create a new simulator instance and generate synthetic data
 42
 43        Parameters
 44        ----------
 45        days : int, optional
 46            Number of days to simulate (default: 180)
 47        seed : int, optional
 48            Random seed for reproducibility (default: 123)
 49        """
 50        self.seed = seed
 51        self._n_days = days
 52        self.data = None
 53        self._generate_data()
 54
 55    def _generate_data(self) -> None:
 56        """Internal data generation function"""
 57        np.random.seed(self.seed)
 58        n = self._n_days
 59
 60        # Generate timestamps
 61        start_date = datetime(2025, 1, 1)
 62        timestamps = [start_date + timedelta(days=i) for i in range(n)]
 63
 64        # Generate numeric variables
 65        hr_mean = np.round(np.random.normal(70, 10, n), 1)
 66        steps = np.round(np.random.normal(7000, 3000, n)).astype(int)
 67        skin_temp = np.round(np.random.normal(36.5, 0.4, n), 1)
 68        ambient_temp = np.round(np.random.normal(23, 3, n), 1)
 69        air_quality_index = np.round(np.random.uniform(20, 120, n)).astype(int)
 70
 71        # Generate activity labels based on steps
 72        activity_label = pd.cut(
 73            steps,
 74            bins=[-np.inf, 3000, 7000, np.inf],
 75            labels=["rest", "walk", "exercise"],
 76        )
 77
 78        # Generate mood score with dependencies
 79        mood_score = (
 80            3
 81            + 0.001 * (steps - 7000)
 82            - 0.01 * (air_quality_index - 50)
 83            + np.random.normal(0, 0.5, n)
 84        )
 85        mood_score = np.round(np.clip(mood_score, 1, 5)).astype(int)
 86
 87        # Generate mood notes
 88        mood_phrases = [
 89            "Felt great today.",
 90            "Very tired.",
 91            "Worked out hard.",
 92            "Anxious and stressed.",
 93            "Calm and productive day.",
 94            "Slept poorly.",
 95            "Long day at work.",
 96        ]
 97
 98        mood_note = []
 99        for ms in mood_score:
100            if ms >= 4:
101                mood_note.append(
102                    np.random.choice(
103                        [mood_phrases[0], mood_phrases[2], mood_phrases[4]]
104                    )
105                )
106            elif ms <= 2:
107                mood_note.append(
108                    np.random.choice(
109                        [mood_phrases[1], mood_phrases[3], mood_phrases[5]]
110                    )
111                )
112            else:
113                mood_note.append(np.random.choice(mood_phrases))
114
115        # Generate sleep quality (dependent variable)
116        activity_penalty = np.where(activity_label == "exercise", -10, 0)
117
118        sleep_quality = (
119            100
120            - 0.2 * hr_mean
121            + 0.01 * steps
122            + 5 * (mood_score - 3)
123            - 0.3 * air_quality_index
124            + activity_penalty
125            + np.random.normal(0, 5, n)
126        )
127        sleep_quality = np.round(np.clip(sleep_quality, 0, 100), 1)
128
129        # Create DataFrame
130        self.data = pd.DataFrame(
131            {
132                "timestamp": timestamps,
133                "hr_mean": hr_mean,
134                "steps": steps,
135                "skin_temp": skin_temp,
136                "ambient_temp": ambient_temp,
137                "activity_label": activity_label,
138                "mood_score": mood_score,
139                "mood_note": mood_note,
140                "air_quality_index": air_quality_index,
141                "sleep_quality": sleep_quality,
142            }
143        )
144
145    def plot_time_series(self, vars: Optional[List[str]] = None) -> plt.Figure:
146        """
147        Plot time series of selected numeric variables
148
149        Parameters
150        ----------
151        vars : list of str, optional
152            Column names to plot (default: ["hr_mean", "steps", "sleep_quality"])
153
154        Returns
155        -------
156        matplotlib.figure.Figure
157            The time series plot
158        """
159        if vars is None:
160            vars = ["hr_mean", "steps", "sleep_quality"]
161
162        fig, axes = plt.subplots(len(vars), 1, figsize=(12, 3 * len(vars)))
163        if len(vars) == 1:
164            axes = [axes]
165
166        for i, var in enumerate(vars):
167            axes[i].plot(self.data["timestamp"], self.data[var], linewidth=2)
168            axes[i].set_title(f"Time Series of {var}")
169            axes[i].set_xlabel("Date")
170            axes[i].set_ylabel(var)
171            axes[i].tick_params(axis="x", rotation=45)
172
173        plt.tight_layout()
174        return fig
175
176    def plot_mood_sleep(self) -> plt.Figure:
177        """
178        Plot the relationship between mood score and sleep quality
179
180        Returns
181        -------
182        matplotlib.figure.Figure
183            Scatter plot with regression line
184        """
185        fig, ax = plt.subplots(figsize=(10, 6))
186
187        # Create jitter for discrete mood scores
188        mood_jitter = self.data["mood_score"] + np.random.normal(
189            0, 0.1, len(self.data)
190        )
191
192        sns.regplot(
193            x=mood_jitter,
194            y=self.data["sleep_quality"],
195            scatter_kws={"alpha": 0.6, "color": "steelblue"},
196            line_kws={"color": "darkred"},
197            ax=ax,
198        )
199
200        ax.set_xlabel("Mood Score")
201        ax.set_ylabel("Sleep Quality")
202        ax.set_title("Sleep Quality vs. Mood Score")
203        ax.set_xticks(range(1, 6))
204        ax.grid(True, alpha=0.3)
205
206        plt.tight_layout()
207        return fig
208
209    def plot_activity_distribution(self) -> plt.Figure:
210        """
211        Plot the distribution of activity labels
212
213        Returns
214        -------
215        matplotlib.figure.Figure
216            Bar chart of activity distribution
217        """
218        fig, ax = plt.subplots(figsize=(10, 6))
219
220        activity_counts = self.data["activity_label"].value_counts()
221        colors = plt.cm.Set2(np.linspace(0, 1, len(activity_counts)))
222
223        bars = ax.bar(
224            activity_counts.index, activity_counts.values, color=colors
225        )
226        ax.set_xlabel("Activity")
227        ax.set_ylabel("Count")
228        ax.set_title("Activity Label Distribution")
229
230        # Add value labels on bars
231        for bar in bars:
232            height = bar.get_height()
233            ax.text(
234                bar.get_x() + bar.get_width() / 2.0,
235                height,
236                f"{int(height)}",
237                ha="center",
238                va="bottom",
239            )
240
241        plt.tight_layout()
242        return fig
243
244    def plot_mood_wordcloud(self) -> plt.Figure:
245        """
246        Create a word cloud from the self-reported mood notes
247
248        Returns
249        -------
250        matplotlib.figure.Figure
251            Word cloud visualization
252        """
253        try:
254            # Combine all mood notes
255            text = " ".join(self.data["mood_note"].tolist())
256
257            # Generate word cloud
258            wordcloud = WordCloud(
259                width=800,
260                height=400,
261                background_color="white",
262                colormap="viridis",
263                max_words=100,
264            ).generate(text)
265
266            fig, ax = plt.subplots(figsize=(12, 6))
267            ax.imshow(wordcloud, interpolation="bilinear")
268            ax.set_title("Mood Notes Word Cloud", fontsize=16)
269            ax.axis("off")
270
271            plt.tight_layout()
272            return fig
273
274        except ImportError:
275            warnings.warn(
276                "wordcloud package not installed. Install with: pip install wordcloud"
277            )
278            return None
279
280    def __repr__(self) -> str:
281        return f"SmartHealthSimulator(days={self._n_days}, seed={self.seed})"
282
283    def __str__(self) -> str:
284        return f"SmartHealthSimulator with {len(self.data)} days of synthetic health data"

Simulates a synthetic, multimodal time series dataset resembling wearable, environmental, behavioral, and self-reported health data over time. Includes numeric, categorical, and text data.

The simulator generates realistic daily records including:

  • Heart rate
  • Steps
  • Skin and ambient temperature
  • Activity label (rest, walk, exercise)
  • Mood score (1-5)
  • Mood notes (short texts)
  • Air quality index
  • Sleep quality (dependent variable)

Includes methods for plotting time series, distributions, relationships, and text-based visualizations.

Examples

>>> sim = SmartHealthSimulator(days=180, seed=42)
>>> print(sim.data.head())
>>> sim.plot_time_series()
>>> sim.plot_mood_sleep()
>>> sim.plot_activity_distribution()
>>> sim.plot_mood_wordcloud()
class DistanceMetrics:
 13class DistanceMetrics:
 14    def __init__(self, vector, matrix):
 15        self.vector = np.array(vector)
 16        self.matrix = np.array(matrix)
 17
 18    def euclidean_distance(self):
 19        """Euclidean (L2) Distance between vector and each row of the matrix."""
 20        return np.linalg.norm(self.matrix - self.vector, axis=1)
 21
 22    def manhattan_distance(self):
 23        """Manhattan (L1) Distance between vector and each row of the matrix."""
 24        return np.sum(np.abs(self.matrix - self.vector), axis=1)
 25
 26    def cosine_distance(self):
 27        """Cosine Distance between vector and each row of the matrix."""
 28        similarities = cosine_similarity(
 29            self.vector.reshape(1, -1), self.matrix
 30        )
 31        return 1 - similarities.flatten()
 32
 33    def mahalanobis_distance(self):
 34        """Mahalanobis Distance between vector and each row of the matrix."""
 35        cov_matrix = np.cov(self.matrix.T)
 36        inv_cov_matrix = np.linalg.inv(cov_matrix)
 37        return [
 38            distance.mahalanobis(self.vector, m, inv_cov_matrix)
 39            for m in self.matrix
 40        ]
 41
 42    def chebyshev_distance(self):
 43        """Chebyshev Distance (Maximum absolute difference)."""
 44        return np.max(np.abs(self.matrix - self.vector), axis=1)
 45
 46    def hamming_distance(self):
 47        """Hamming Distance between vector and each row of the matrix (for binary data)."""
 48        return np.sum(self.matrix != self.vector, axis=1)
 49
 50    def jaccard_distance(self):
 51        """Jaccard Distance between vector and each row of the matrix (for binary data)."""
 52        return [
 53            1 - jaccard_score(self.vector, m, average="binary")
 54            for m in self.matrix
 55        ]
 56
 57    def weighted_euclidean_distance(self, weights):
 58        """Weighted Euclidean Distance between vector and each row of the matrix."""
 59        weights = np.array(weights)
 60        return np.sqrt(
 61            np.sum(weights * (self.matrix - self.vector) ** 2, axis=1)
 62        )
 63
 64    def kullback_leibler_divergence(self, P, Q):
 65        """Kullback-Leibler Divergence between two distributions P and Q."""
 66        return entropy(P, Q)
 67
 68    def wasserstein_distance(self, distribution_1, distribution_2):
 69        """Wasserstein Distance (Earth Mover's Distance) between two distributions."""
 70        return wasserstein_distance(distribution_1, distribution_2)
 71
 72    def pearson_correlation(self):
 73        """Pearson Correlation between the vector and each row of the matrix."""
 74        return [pearsonr(self.vector, m)[0] for m in self.matrix]
 75
 76    def jensen_shannon_divergence(self, P, Q):
 77        """Jensen-Shannon Divergence between two distributions."""
 78        M = 0.5 * (P + Q)
 79        return 0.5 * (kl_div(P, M).sum() + kl_div(Q, M).sum())
 80
 81    def total_variation_distance(self, P, Q):
 82        """Total Variation Distance between two distributions."""
 83        return 0.5 * np.sum(np.abs(P - Q))
 84
 85    def qqplot_with_summary(
 86        self, data1, data2, label1="Sample 1", label2="Sample 2"
 87    ):
 88        data1 = np.asarray(data1)
 89        data2 = np.asarray(data2)
 90
 91        # Remove NaN values
 92        data1 = data1[~np.isnan(data1)]
 93        data2 = data2[~np.isnan(data2)]
 94
 95        # Q–Q plot
 96        n_quantiles = min(len(data1), len(data2))
 97        quantiles1 = np.percentile(data1, np.linspace(0, 100, n_quantiles))
 98        quantiles2 = np.percentile(data2, np.linspace(0, 100, n_quantiles))
 99
100        plt.figure(figsize=(6, 6))
101        plt.scatter(quantiles1, quantiles2, alpha=0.7)
102        min_val = min(quantiles1.min(), quantiles2.min())
103        max_val = max(quantiles1.max(), quantiles2.max())
104        plt.plot([min_val, max_val], [min_val, max_val], "r--", label="y = x")
105        plt.xlabel(label1)
106        plt.ylabel(label2)
107        plt.title("Q–Q Plot")
108        plt.legend()
109        plt.grid(True)
110        plt.show()
111
112        # Descriptive stats
113        mean1, mean2 = np.mean(data1), np.mean(data2)
114        std1, std2 = np.std(data1, ddof=1), np.std(data2, ddof=1)
115        n1, n2 = len(data1), len(data2)
116
117        # Kolmogorov–Smirnov test
118        ks_stat, ks_p = stats.ks_2samp(data1, data2)
119
120        # Anderson–Darling test (two-sample)
121        ad_result = stats.anderson_ksamp([data1, data2])
122        ad_stat = ad_result.statistic
123        ad_p = ad_result.significance_level / 100  # convert % to proportion
124
125        # Quantile correlation
126        corr = np.corrcoef(quantiles1, quantiles2)[0, 1]
127
128        # Summary table
129        summary = pd.DataFrame(
130            {
131                "Statistic": [
132                    "Sample size",
133                    "Mean",
134                    "Std. deviation",
135                    "KS statistic",
136                    "KS p-value",
137                    "AD statistic",
138                    "AD p-value",
139                    "Quantile correlation",
140                ],
141                label1: [n1, mean1, std1, ks_stat, ks_p, ad_stat, ad_p, corr],
142                label2: [n2, mean2, std2, "", "", "", "", ""],
143            }
144        )
145
146        return summary
class MaximumEntropyBootstrap:
 26class MaximumEntropyBootstrap:
 27    """
 28    Maximum Entropy Bootstrap for time series inference with plotting and hypothesis testing.
 29    """
 30
 31    def __init__(self, trim: float = 0.10, random_state: Optional[int] = None):
 32        self.trim = trim
 33        self.random_state = random_state
 34        if random_state is not None:
 35            np.random.seed(random_state)
 36
 37        # Storage for intermediate results
 38        self.order_stats_ = None
 39        self.ordering_index_ = None
 40        self.intermediate_points_ = None
 41        self.interval_means_ = None
 42        self.limits_ = None
 43        self.original_series_ = None
 44
 45    def _calculate_trimmed_mean_diff(self, x: np.ndarray) -> float:
 46        """Calculate trimmed mean of consecutive differences."""
 47        diffs = np.diff(x)
 48        if len(diffs) == 0:
 49            return 0.0
 50
 51        lower_bound = np.percentile(diffs, self.trim * 100)
 52        upper_bound = np.percentile(diffs, (1 - self.trim) * 100)
 53        trimmed_diffs = diffs[(diffs >= lower_bound) & (diffs <= upper_bound)]
 54
 55        return (
 56            np.mean(trimmed_diffs) if len(trimmed_diffs) > 0 else np.mean(diffs)
 57        )
 58
 59    def _calculate_interval_means(self, order_stats: np.ndarray) -> np.ndarray:
 60        """Calculate means for each interval using mean-preserving constraint."""
 61        T = len(order_stats)
 62        means = np.zeros(T)
 63
 64        means[0] = 0.75 * order_stats[0] + 0.25 * order_stats[1]
 65
 66        for k in range(1, T - 1):
 67            means[k] = (
 68                0.25 * order_stats[k - 1]
 69                + 0.50 * order_stats[k]
 70                + 0.25 * order_stats[k + 1]
 71            )
 72
 73        means[T - 1] = 0.25 * order_stats[T - 2] + 0.75 * order_stats[T - 1]
 74
 75        return means
 76
 77    def fit(
 78        self, x: Union[np.ndarray, List, pd.Series]
 79    ) -> "MaximumEntropyBootstrap":
 80        """Fit the ME bootstrap to the time series."""
 81        x = np.asarray(x)
 82        if len(x) < 3:
 83            raise ValueError("Time series must have at least 3 observations")
 84
 85        self.original_series_ = x.copy()
 86        T = len(x)
 87
 88        # Step 1: Sort data and store ordering index
 89        self.ordering_index_ = np.argsort(x)
 90        self.order_stats_ = x[self.ordering_index_]
 91
 92        # Step 2: Compute intermediate points
 93        self.intermediate_points_ = (
 94            self.order_stats_[:-1] + self.order_stats_[1:]
 95        ) / 2
 96
 97        # Step 3: Compute limits for tails
 98        m_trim = self._calculate_trimmed_mean_diff(x)
 99        z0 = self.order_stats_[0] - m_trim
100        zT = self.order_stats_[-1] + m_trim
101
102        self.limits_ = (z0, zT)
103
104        # Step 4: Compute interval means
105        self.interval_means_ = self._calculate_interval_means(self.order_stats_)
106
107        return self
108
109    def _generate_me_quantiles(self, size: int) -> np.ndarray:
110        """Generate quantiles from maximum entropy density."""
111        z0, zT = self.limits_
112        all_z_points = np.concatenate([[z0], self.intermediate_points_, [zT]])
113
114        u = np.random.uniform(0, 1, size)
115        quantiles = np.zeros(size)
116        n_intervals = len(all_z_points) - 1
117
118        for i in range(size):
119            interval_idx = int(u[i] * n_intervals)
120            interval_idx = min(interval_idx, n_intervals - 1)
121
122            interval_start = all_z_points[interval_idx]
123            interval_end = all_z_points[interval_idx + 1]
124            interval_frac = (u[i] * n_intervals) - interval_idx
125
126            quantiles[i] = interval_start + interval_frac * (
127                interval_end - interval_start
128            )
129
130        return quantiles
131
132    def sample(self, reps: int = 999) -> np.ndarray:
133        """Generate bootstrap replicates."""
134        if self.order_stats_ is None:
135            raise ValueError("Must call fit() before sample()")
136
137        T = len(self.order_stats_)
138        ensemble = np.zeros((T, reps))
139
140        for j in range(reps):
141            me_quantiles = self._generate_me_quantiles(T)
142            sorted_quantiles = np.sort(me_quantiles)
143
144            original_order_quantiles = np.zeros(T)
145            for i, idx in enumerate(self.ordering_index_):
146                original_order_quantiles[idx] = sorted_quantiles[i]
147
148            ensemble[:, j] = original_order_quantiles
149
150        return ensemble
151
152    # ==================== PLOTTING METHODS ====================
153
154    def plot_me_density(self, figsize: Tuple[int, int] = (12, 8)) -> plt.Figure:
155        """Plot the maximum entropy density with intervals."""
156        if self.order_stats_ is None:
157            raise ValueError("Must call fit() first")
158
159        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize)
160
161        # Plot 1: ME Density Intervals
162        z0, zT = self.limits_
163        all_z_points = np.concatenate([[z0], self.intermediate_points_, [zT]])
164
165        for i in range(len(all_z_points) - 1):
166            ax1.axvspan(
167                all_z_points[i],
168                all_z_points[i + 1],
169                alpha=0.3,
170                label=f"Interval {i+1}" if i == 0 else "",
171            )
172            ax1.axvline(all_z_points[i], color="red", linestyle="--", alpha=0.7)
173
174        ax1.axvline(all_z_points[-1], color="red", linestyle="--", alpha=0.7)
175        ax1.set_title("Maximum Entropy Density Intervals")
176        ax1.set_xlabel("Value")
177        ax1.set_ylabel("Intervals")
178        ax1.legend()
179
180        # Plot 2: Original vs Order Statistics
181        ax2.plot(
182            self.original_series_, "o-", label="Original Series", alpha=0.7
183        )
184        ax2.plot(self.order_stats_, "s-", label="Order Statistics", alpha=0.7)
185        ax2.set_title("Original Series vs Order Statistics")
186        ax2.set_xlabel("Index")
187        ax2.set_ylabel("Value")
188        ax2.legend()
189        ax2.grid(True, alpha=0.3)
190
191        plt.tight_layout()
192        return fig
193
194    def plot_bootstrap_ensemble(
195        self, reps: int = 50, figsize: Tuple[int, int] = (15, 10)
196    ) -> plt.Figure:
197        """Plot multiple bootstrap replicates with original series."""
198        ensemble = self.sample(reps)
199
200        fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=figsize)
201
202        # Plot 1: All replicates
203        time_index = np.arange(len(self.original_series_))
204        for j in range(min(reps, 50)):  # Limit to 50 for clarity
205            ax1.plot(time_index, ensemble[:, j], alpha=0.1, color="blue")
206
207        ax1.plot(
208            time_index,
209            self.original_series_,
210            "r-",
211            linewidth=2,
212            label="Original",
213        )
214        ax1.set_title(f"ME Bootstrap Ensemble ({reps} replicates)")
215        ax1.set_xlabel("Time")
216        ax1.set_ylabel("Value")
217        ax1.legend()
218        ax1.grid(True, alpha=0.3)
219
220        # Plot 2: Mean and confidence intervals
221        mean_ensemble = np.mean(ensemble, axis=1)
222        ci_lower = np.percentile(ensemble, 2.5, axis=1)
223        ci_upper = np.percentile(ensemble, 97.5, axis=1)
224
225        ax2.fill_between(
226            time_index, ci_lower, ci_upper, alpha=0.3, label="95% CI"
227        )
228        ax2.plot(time_index, mean_ensemble, "b-", label="Ensemble Mean")
229        ax2.plot(time_index, self.original_series_, "r-", label="Original")
230        ax2.set_title("Ensemble Mean and 95% Confidence Intervals")
231        ax2.set_xlabel("Time")
232        ax2.set_ylabel("Value")
233        ax2.legend()
234        ax2.grid(True, alpha=0.3)
235
236        # Plot 3: Distribution at selected time points
237        if len(time_index) >= 5:
238            selected_times = np.linspace(0, len(time_index) - 1, 5, dtype=int)
239            for i, t in enumerate(selected_times):
240                ax3.hist(
241                    ensemble[t, :],
242                    bins=30,
243                    alpha=0.5,
244                    label=f"Time {t}",
245                    density=True,
246                )
247            ax3.set_title("Distribution at Selected Time Points")
248            ax3.set_xlabel("Value")
249            ax3.set_ylabel("Density")
250            ax3.legend()
251
252        plt.tight_layout()
253        return fig
254
255    def plot_sampling_distribution(
256        self,
257        statistic: Callable,
258        reps: int = 999,
259        figsize: Tuple[int, int] = (12, 10),
260    ) -> plt.Figure:
261        """Plot sampling distribution of a statistic."""
262        ensemble = self.sample(reps)
263
264        # Calculate statistic for each bootstrap sample
265        stats_boot = np.zeros(reps)
266        for j in range(reps):
267            stats_boot[j] = statistic(ensemble[:, j])
268
269        original_stat = statistic(self.original_series_)
270
271        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
272
273        # Histogram with KDE
274        ax1.hist(stats_boot, bins=30, density=True, alpha=0.7, color="skyblue")
275        ax1.axvline(
276            original_stat,
277            color="red",
278            linestyle="--",
279            linewidth=2,
280            label=f"Original: {original_stat:.3f}",
281        )
282        ax1.axvline(
283            np.mean(stats_boot),
284            color="green",
285            linestyle="--",
286            linewidth=2,
287            label=f"Bootstrap Mean: {np.mean(stats_boot):.3f}",
288        )
289        ax1.set_title("Bootstrap Sampling Distribution")
290        ax1.set_xlabel("Statistic Value")
291        ax1.set_ylabel("Density")
292        ax1.legend()
293        ax1.grid(True, alpha=0.3)
294
295        # Q-Q plot for normality check
296        stats.probplot(stats_boot, dist="norm", plot=ax2)
297        ax2.set_title("Q-Q Plot for Normality Check")
298
299        plt.tight_layout()
300        return fig, stats_boot
301
302    # ==================== HYPOTHESIS TESTING METHODS ====================
303
304    def hypothesis_test(
305        self,
306        statistic: Callable,
307        null_value: float = 0,
308        alternative: str = "two-sided",
309        reps: int = 999,
310        confidence: float = 0.95,
311    ) -> HypothesisTestResult:
312        """
313        Perform hypothesis test using ME bootstrap.
314
315        Parameters
316        ----------
317        statistic : callable
318            Function that computes the test statistic
319        null_value : float, default=0
320            Value under the null hypothesis
321        alternative : str, default='two-sided'
322            Alternative hypothesis: 'two-sided', 'less', or 'greater'
323        reps : int, default=999
324            Number of bootstrap replicates
325        confidence : float, default=0.95
326            Confidence level for interval
327
328        Returns
329        -------
330        HypothesisTestResult
331        """
332        if alternative not in ["two-sided", "less", "greater"]:
333            raise ValueError(
334                "Alternative must be 'two-sided', 'less', or 'greater'"
335            )
336
337        ensemble = self.sample(reps)
338
339        # Calculate statistic for each bootstrap sample
340        stats_boot = np.zeros(reps)
341        for j in range(reps):
342            stats_boot[j] = statistic(ensemble[:, j])
343
344        original_stat = statistic(self.original_series_)
345
346        # Calculate p-value based on alternative hypothesis
347        if alternative == "two-sided":
348            p_value = 2 * min(
349                np.mean(stats_boot <= null_value),
350                np.mean(stats_boot >= null_value),
351            )
352            ci_lower = np.percentile(stats_boot, (1 - confidence) / 2 * 100)
353            ci_upper = np.percentile(
354                stats_boot, (1 - (1 - confidence) / 2) * 100
355            )
356        elif alternative == "less":
357            p_value = np.mean(stats_boot <= null_value)
358            ci_lower = np.percentile(stats_boot, (1 - confidence) * 100)
359            ci_upper = np.inf
360        else:  # 'greater'
361            p_value = np.mean(stats_boot >= null_value)
362            ci_lower = -np.inf
363            ci_upper = np.percentile(stats_boot, confidence * 100)
364
365        reject_null = p_value < (1 - confidence)
366
367        return HypothesisTestResult(
368            statistic=original_stat,
369            p_value=p_value,
370            ci_lower=ci_lower,
371            ci_upper=ci_upper,
372            null_value=null_value,
373            alternative=alternative,
374            reject_null=reject_null,
375            test_type="bootstrap",
376        )
377
378    def test_mean(
379        self,
380        null_value: float = 0,
381        alternative: str = "two-sided",
382        reps: int = 999,
383        confidence: float = 0.95,
384    ) -> HypothesisTestResult:
385        """Test hypothesis about the mean."""
386        return self.hypothesis_test(
387            statistic=np.mean,
388            null_value=null_value,
389            alternative=alternative,
390            reps=reps,
391            confidence=confidence,
392        )
393
394    def test_median(
395        self,
396        null_value: float = 0,
397        alternative: str = "two-sided",
398        reps: int = 999,
399        confidence: float = 0.95,
400    ) -> HypothesisTestResult:
401        """Test hypothesis about the median."""
402        return self.hypothesis_test(
403            statistic=np.median,
404            null_value=null_value,
405            alternative=alternative,
406            reps=reps,
407            confidence=confidence,
408        )
409
410    def test_variance(
411        self,
412        null_value: float = 1,
413        alternative: str = "two-sided",
414        reps: int = 999,
415        confidence: float = 0.95,
416    ) -> HypothesisTestResult:
417        """Test hypothesis about the variance."""
418        return self.hypothesis_test(
419            statistic=np.var,
420            null_value=null_value,
421            alternative=alternative,
422            reps=reps,
423            confidence=confidence,
424        )
425
426    def test_correlation(
427        self,
428        y: np.ndarray,
429        null_value: float = 0,
430        alternative: str = "two-sided",
431        reps: int = 999,
432        confidence: float = 0.95,
433    ) -> HypothesisTestResult:
434        """Test hypothesis about correlation with another series."""
435        if len(y) != len(self.original_series_):
436            raise ValueError("y must have same length as original series")
437
438        def corr_statistic(x):
439            return np.corrcoef(x, y)[0, 1]
440
441        return self.hypothesis_test(
442            statistic=corr_statistic,
443            null_value=null_value,
444            alternative=alternative,
445            reps=reps,
446            confidence=confidence,
447        )
448
449    def compare_means(
450        self,
451        y: np.ndarray,
452        null_value: float = 0,
453        alternative: str = "two-sided",
454        reps: int = 999,
455        confidence: float = 0.95,
456    ) -> HypothesisTestResult:
457        """Test for difference in means between two series."""
458
459        def mean_diff_statistic(x):
460            return np.mean(x) - np.mean(y)
461
462        return self.hypothesis_test(
463            statistic=mean_diff_statistic,
464            null_value=null_value,
465            alternative=alternative,
466            reps=reps,
467            confidence=confidence,
468        )
469
470    # ==================== UTILITY METHODS ====================
471
472    def get_params(self) -> dict:
473        """Get parameters of the fitted ME bootstrap."""
474        return {
475            "order_stats": self.order_stats_,
476            "ordering_index": self.ordering_index_,
477            "intermediate_points": self.intermediate_points_,
478            "interval_means": self.interval_means_,
479            "limits": self.limits_,
480            "trim": self.trim,
481        }
482
483    def summary(self) -> pd.DataFrame:
484        """Generate summary statistics of the original series."""
485        if self.original_series_ is None:
486            raise ValueError("Must call fit() first")
487
488        x = self.original_series_
489        stats_dict = {
490            "n_observations": len(x),
491            "mean": np.mean(x),
492            "median": np.median(x),
493            "std_dev": np.std(x),
494            "variance": np.var(x),
495            "min": np.min(x),
496            "max": np.max(x),
497            "skewness": stats.skew(x),
498            "kurtosis": stats.kurtosis(x),
499        }
500
501        return pd.DataFrame([stats_dict])

Maximum Entropy Bootstrap for time series inference with plotting and hypothesis testing.

def fit( self, x: Union[numpy.ndarray, List, pandas.core.series.Series]) -> MaximumEntropyBootstrap:
 77    def fit(
 78        self, x: Union[np.ndarray, List, pd.Series]
 79    ) -> "MaximumEntropyBootstrap":
 80        """Fit the ME bootstrap to the time series."""
 81        x = np.asarray(x)
 82        if len(x) < 3:
 83            raise ValueError("Time series must have at least 3 observations")
 84
 85        self.original_series_ = x.copy()
 86        T = len(x)
 87
 88        # Step 1: Sort data and store ordering index
 89        self.ordering_index_ = np.argsort(x)
 90        self.order_stats_ = x[self.ordering_index_]
 91
 92        # Step 2: Compute intermediate points
 93        self.intermediate_points_ = (
 94            self.order_stats_[:-1] + self.order_stats_[1:]
 95        ) / 2
 96
 97        # Step 3: Compute limits for tails
 98        m_trim = self._calculate_trimmed_mean_diff(x)
 99        z0 = self.order_stats_[0] - m_trim
100        zT = self.order_stats_[-1] + m_trim
101
102        self.limits_ = (z0, zT)
103
104        # Step 4: Compute interval means
105        self.interval_means_ = self._calculate_interval_means(self.order_stats_)
106
107        return self

Fit the ME bootstrap to the time series.

class TsDistroSimulator:
 28class TsDistroSimulator:
 29    def __init__(
 30        self,
 31        kernel="rbf",
 32        backend="numpy",
 33        kde_kernel="gaussian",
 34        random_state=None,
 35        residual_sampling="bootstrap",
 36        block_size=None,
 37        gmm_components=3,
 38    ):
 39        self.kernel = kernel
 40        self.backend = backend
 41        self.random_state = random_state
 42        self.residual_sampling = residual_sampling
 43        self.block_size = block_size
 44        self.gmm_components = gmm_components
 45        self.kde_kernel = kde_kernel
 46        self.Y_ = None
 47        self.n_samples_ = None
 48
 49        if random_state is not None:
 50            np.random.seed(random_state)
 51            if JAX_AVAILABLE:
 52                key = jax.random.PRNGKey(random_state)
 53
 54        valid_sampling_methods = [
 55            "bootstrap",
 56            "kde",
 57            "gmm",
 58            "block-bootstrap",
 59            "me-bootstrap",
 60        ]
 61        if residual_sampling not in valid_sampling_methods:
 62            raise ValueError(
 63                f"residual_sampling must be one of {valid_sampling_methods}"
 64            )
 65
 66        if backend in ["gpu", "tpu"] and JAX_AVAILABLE:
 67            self._setup_jax_backend()
 68        elif backend in ["gpu", "tpu"] and not JAX_AVAILABLE:
 69            print("JAX not available. Falling back to NumPy backend.")
 70            self.backend = "numpy"
 71
 72        self.model = None
 73        self.residuals_ = None
 74        self.X_dist = None
 75        self.is_fitted = False
 76        self.best_params_ = None
 77        self.best_score_ = None
 78        self.kde_model_ = None
 79        self.gmm_model_ = None
 80
 81    def _setup_jax_backend(self):
 82        if not JAX_AVAILABLE:
 83            raise ImportError("JAX is required for GPU/TPU backend")
 84
 85        @jit
 86        def pairwise_sq_dists_jax(X1, X2):
 87            X1_sq = jnp.sum(X1**2, axis=1)[:, jnp.newaxis]
 88            X2_sq = jnp.sum(X2**2, axis=1)[jnp.newaxis, :]
 89            return X1_sq + X2_sq - 2 * X1 @ X2.T
 90
 91        @jit
 92        def cdist_jax(X1, X2):
 93            return vmap(
 94                lambda x: vmap(lambda y: jnp.sqrt(jnp.sum((x - y) ** 2)))(X2)
 95            )(X1)
 96
 97        self._pairwise_sq_dists_jax = pairwise_sq_dists_jax
 98        self._cdist_jax = cdist_jax
 99
100    def _create_model(self, gamma, alpha, lags=20, n_hidden_features=5):
101        return ns.MTS(
102            obj=KernelRidge(kernel=self.kernel, gamma=gamma, alpha=alpha),
103            lags=lags,
104            n_hidden_features=n_hidden_features,
105        )
106
107    def _fit_residual_sampler(self, **kwargs):
108        if self.residuals_ is None or len(self.residuals_) == 0:
109            raise ValueError("No residuals available for fitting sampler")
110
111        if self.residual_sampling == "kde":
112            kernel_bandwidths = {"bandwidth": np.logspace(-6, 6, 150)}
113            grid = GridSearchCV(
114                KernelDensity(kernel=self.kde_kernel, **kwargs),
115                param_grid=kernel_bandwidths,
116            )
117            grid.fit(self.residuals_)
118            self.kde_model_ = grid.best_estimator_
119            self.kde_model_.fit(self.residuals_)
120
121        elif self.residual_sampling == "gmm":
122            self.gmm_model_ = GaussianMixture(
123                n_components=min(self.gmm_components, len(self.residuals_)),
124                random_state=self.random_state,
125                covariance_type="full",
126            )
127            self.gmm_model_.fit(self.residuals_)
128
129    def _sample_residuals(self, num_samples, random_state=123):
130        if self.residuals_ is None:
131            raise ValueError("No residuals available for sampling")
132
133        n = len(self.residuals_)
134
135        if self.residual_sampling == "bootstrap":
136            np.random.seed(random_state)
137            if num_samples <= n:
138                idx = np.random.choice(n, num_samples, replace=True)
139                return self.residuals_[idx]
140            else:
141                n_repeats = (num_samples // n) + 1
142                tiled = np.tile(self.residuals_, (n_repeats, 1))
143                idx = np.random.choice(len(tiled), num_samples, replace=False)
144                return tiled[idx]
145
146        elif self.residual_sampling == "kde":
147            if self.kde_model_ is None:
148                raise ValueError(
149                    "KDE model not fitted. Call _fit_residual_sampler first."
150                )
151
152            samples = self.kde_model_.sample(
153                num_samples, random_state=random_state
154            )
155
156            if samples.ndim == 1:
157                samples = samples.reshape(-1, 1)
158            return samples
159
160        elif self.residual_sampling == "gmm":
161            if self.gmm_model_ is None:
162                raise ValueError(
163                    "GMM model not fitted. Call _fit_residual_sampler first."
164                )
165
166            # Set random state before sampling
167            np.random.seed(random_state)
168            samples = self.gmm_model_.sample(num_samples)[0]
169
170            if samples.ndim == 1:
171                samples = samples.reshape(-1, 1)
172            return samples
173
174        elif self.residual_sampling == "me-bootstrap":
175            meb = MaximumEntropyBootstrap(random_state=random_state)
176            residuals = self.residuals_.flatten()
177            if residuals.shape[0] < num_samples:
178                repeats = int(np.ceil(num_samples / residuals.shape[0]))
179                residuals = np.tile(residuals, repeats)[:num_samples]
180            else:
181                residuals = residuals[:num_samples]
182            meb.fit(residuals)
183            samples = meb.sample(1)[:, 0].reshape(-1, 1)
184            # Ensure we have exactly num_samples
185            if len(samples) < num_samples:
186                n_repeats = (num_samples // len(samples)) + 1
187                samples = np.tile(samples, (n_repeats, 1))[:num_samples]
188            return samples
189
190        elif self.residual_sampling == "block-bootstrap":
191            samples = bootstrap(
192                self.residuals_,
193                num_samples,
194                block_size=self.block_size,
195                seed=random_state,
196            )
197            # Ensure correct shape
198            if samples.ndim == 1:
199                samples = samples.reshape(-1, 1)
200            return samples
201
202        else:
203            raise ValueError(
204                f"Unknown sampling method: {self.residual_sampling}"
205            )
206
207    def _pairwise_sq_dists(self, X1, X2):
208        if self.backend in ["gpu", "tpu"] and JAX_AVAILABLE:
209            X1_jax = jnp.array(X1)
210            X2_jax = jnp.array(X2)
211            result = self._pairwise_sq_dists_jax(X1_jax, X2_jax)
212            return np.array(result)
213        else:
214            X1 = np.atleast_2d(X1)
215            X2 = np.atleast_2d(X2)
216            return (
217                np.sum(X1**2, axis=1)[:, np.newaxis]
218                + np.sum(X2**2, axis=1)[np.newaxis, :]
219                - 2 * X1 @ X2.T
220            )
221
222    def _mmd(self, u, v, kernel_sigma=1):
223        if u.ndim == 1:
224            u = u.reshape(-1, 1)
225        if v.ndim == 1:
226            v = v.reshape(-1, 1)
227
228        def kmat(A, B):
229            return np.exp(
230                -self._pairwise_sq_dists(A, B) / (2 * kernel_sigma**2)
231            )
232
233        return (
234            np.mean(kmat(u, u)) + np.mean(kmat(v, v)) - 2 * np.mean(kmat(u, v))
235        )
236
237    def _custom_energy_distance(self, u, v):
238        if u.ndim == 1:
239            u = u.reshape(-1, 1)
240        if v.ndim == 1:
241            v = v.reshape(-1, 1)
242
243        n, d = u.shape
244        m = v.shape[0]
245
246        if self.backend in ["gpu", "tpu"] and JAX_AVAILABLE:
247            u_jax = jnp.array(u)
248            v_jax = jnp.array(v)
249            dist_xx = self._cdist_jax(u_jax, u_jax)
250            dist_yy = self._cdist_jax(v_jax, v_jax)
251            dist_xy = self._cdist_jax(u_jax, v_jax)
252            term1 = 2 * jnp.sum(dist_xy) / (n * m)
253            term2 = jnp.sum(dist_xx) / (n * n)
254            term3 = jnp.sum(dist_yy) / (m * m)
255            return float(term1 - term2 - term3)
256        else:
257            dist_xx = cdist(u, u, metric="euclidean")
258            dist_yy = cdist(v, v, metric="euclidean")
259            dist_xy = cdist(u, v, metric="euclidean")
260            term1 = 2 * np.sum(dist_xy) / (n * m)
261            term2 = np.sum(dist_xx) / (n * n)
262            term3 = np.sum(dist_yy) / (m * m)
263            return term1 - term2 - term3
264
265    def _generate_pseudo_single(self, random_state=123):
266        """
267        Generate a single synthetic realization.
268
269        Returns original data (structure) + resampled residuals (noise)
270        Each call produces a different realization due to different residual samples.
271        """
272        if not self.is_fitted:
273            raise ValueError("Model not fitted. Call fit() first.")
274
275        n_rows = self.n_samples_
276
277        # Base: original time series structure
278        base = self.Y_.copy()
279        if base.ndim == 1:
280            base = base.reshape(-1, 1)
281
282        # Noise: sample new residuals from learned distribution
283        residuals = self._sample_residuals(n_rows, random_state)
284
285        # Ensure residuals match the size of base
286        # This is needed because model residuals may be shorter due to lags
287        if residuals.shape[0] < n_rows:
288            n_repeats = (n_rows // residuals.shape[0]) + 1
289            residuals = np.tile(residuals, (n_repeats, 1))[:n_rows]
290        elif residuals.shape[0] > n_rows:
291            residuals = residuals[:n_rows]
292
293        # Return: structure + new noise realization
294        return base + residuals
295
296    def fit(self, Y, metric="energy", n_trials=50, **kwargs):
297        if Y.ndim == 1:
298            Y = Y.reshape(-1, 1)
299
300        n, d = Y.shape
301        self.n_features_ = d
302        self.n_samples_ = n
303        self.Y_ = Y  # Store once before optimization
304
305        self.X_dist = np.random.normal(0, 1, (n, d))
306
307        def objective(trial):
308            sigma = trial.suggest_float("sigma", 0.01, 10, log=True)
309            lambd = trial.suggest_float("lambd", 1e-5, 1, log=True)
310            lags = trial.suggest_int("lags", 1, 50)
311            n_hidden_features = trial.suggest_int("n_hidden_features", 1, 20)
312            gamma = 1 / (2 * sigma**2)
313
314            model = self._create_model(gamma, lambd, lags, n_hidden_features)
315            model.fit(Y)
316
317            # Generate synthetic sample using this model's residuals
318            Y_sim = self._generate_pseudo_with_model(
319                model, model.residuals_, n, random_state=trial.number
320            )
321
322            if metric == "energy":
323                dist_val = self._custom_energy_distance(Y, Y_sim)
324            elif metric == "mmd":
325                dist_val = self._mmd(Y, Y_sim)
326            elif metric == "wasserstein" and d == 1:
327                dist_val = stats.wasserstein_distance(
328                    Y.flatten(), Y_sim.flatten()
329                )
330            else:
331                raise ValueError("Invalid metric for dimension")
332
333            return dist_val
334
335        study = optuna.create_study(direction="minimize")
336        study.optimize(objective, n_trials=n_trials, **kwargs)
337
338        self.best_params_ = study.best_params
339        self.best_score_ = study.best_value
340        sigma = self.best_params_["sigma"]
341        lambd = self.best_params_["lambd"]
342        lags = self.best_params_["lags"]
343        n_hidden_features = self.best_params_["n_hidden_features"]
344        gamma = 1 / (2 * sigma**2)
345
346        self.model = self._create_model(gamma, lambd, lags, n_hidden_features)
347        self.model.fit(Y)
348
349        self.residuals_ = self.model.residuals_
350
351        self._fit_residual_sampler()
352        self.is_fitted = True
353
354        print(f"  Best energy distance: {self.best_score_:.6f}")
355        print(f"  Best lags: {lags}, n_hidden_features: {n_hidden_features}")
356
357        return self
358
359    def _generate_pseudo_with_model(
360        self, model, residuals, num_samples, random_state=None
361    ):
362        """Helper function for optimization - temporarily uses different residuals"""
363        # Temporarily store and swap residual models
364        original_residuals = self.residuals_
365        original_kde = self.kde_model_
366        original_gmm = self.gmm_model_
367
368        self.residuals_ = residuals
369
370        # Only fit if using kde or gmm
371        if self.residual_sampling in ["kde", "gmm"]:
372            self._fit_residual_sampler()
373
374        # Length of actual residuals from the model
375        residual_len = len(residuals)
376
377        # Use provided random state or generate one
378        if random_state is None:
379            random_state = np.random.randint(0, 10000)
380
381        # Sample residuals matching the residual length
382        sampled_residuals = self._sample_residuals(
383            residual_len, random_state=random_state
384        )
385
386        # Restore original state
387        self.residuals_ = original_residuals
388        self.kde_model_ = original_kde
389        self.gmm_model_ = original_gmm
390
391        # The model with lags produces residuals shorter than original data
392        # We need to align: use only the portion of Y that corresponds to residuals
393        # Typically, if lags=L, residuals start from index L
394        y_slice = self.Y_[-residual_len:]  # Take the last residual_len points
395
396        # Return: aligned data + resampled residuals
397        return y_slice + sampled_residuals
398
399    def sample(self, n_samples=1):
400        """
401        Generate synthetic samples via distribution matching.
402
403        Each sample is: original structure + resampled residuals
404
405        Parameters:
406        -----------
407        n_samples : int, default=1
408            Number of synthetic realizations to generate
409
410        Returns:
411        --------
412        samples : ndarray
413            - If Y was univariate (n_rows, 1): returns shape (n_rows, n_samples)
414            - If Y was multivariate (n_rows, n_features): returns shape (n_features, n_rows, n_samples)
415        """
416        if not self.is_fitted:
417            raise ValueError("Model not fitted. Call fit() first.")
418
419        # Generate n_samples realizations, each with shape (n_rows, n_features)
420        samples_list = []
421        for i, _ in enumerate(range(n_samples)):
422            sample = self._generate_pseudo_single(
423                random_state=1000 + i
424            )  # Shape: (n_rows, n_features)
425            samples_list.append(sample)
426
427        # Stack to get shape (n_samples, n_rows, n_features)
428        stacked = np.stack(samples_list, axis=0)
429
430        # If univariate (n_features == 1), return (n_rows, n_samples)
431        if self.n_features_ == 1:
432            result = stacked.squeeze(
433                axis=2
434            ).T  # (n_samples, n_rows) -> (n_rows, n_samples)
435        else:
436            # If multivariate, return (n_features, n_rows, n_samples)
437            result = stacked.transpose(
438                2, 1, 0
439            )  # (n_samples, n_rows, n_features) -> (n_features, n_rows, n_samples)
440
441        return result
442
443    def compare_distributions(self, Y_orig, Y_sim, save_prefix=""):
444        """
445        Visual comparison of original and synthetic distributions.
446
447        Parameters:
448        -----------
449        Y_orig : array-like
450            Original data
451        Y_sim : array-like
452            Synthetic data
453        save_prefix : str, default=''
454            Prefix for saving plots
455        """
456        if Y_orig.ndim == 1:
457            Y_orig = Y_orig.reshape(-1, 1)
458        if Y_sim.ndim == 1:
459            Y_sim = Y_sim.reshape(-1, 1)
460
461        n, d = Y_orig.shape
462
463        # Create a figure with subplots for statistical tests
464        fig, axes = plt.subplots(2, d, figsize=(6 * d, 10))
465        if d == 1:
466            axes = axes.reshape(2, 1)
467
468        # Statistical test results storage
469        ks_results = []
470        ad_results = []
471
472        for i in range(d):
473            # Top row: Histograms with statistical test annotations
474            ax_hist = axes[0, i]
475
476            # Plot histograms
477            ax_hist.hist(
478                Y_orig[:, i],
479                alpha=0.5,
480                label="Original",
481                density=True,
482                bins=20,
483                color="blue",
484            )
485            ax_hist.hist(
486                Y_sim[:, i],
487                alpha=0.5,
488                label="Simulated",
489                density=True,
490                bins=20,
491                color="red",
492            )
493
494            # Perform statistical tests
495            ks_stat, ks_pvalue = stats.ks_2samp(Y_orig[:, i], Y_sim[:, i])
496            ks_results.append((ks_stat, ks_pvalue))
497
498            ad_result = stats.anderson_ksamp([Y_orig[:, i], Y_sim[:, i]])
499            ad_stat = ad_result.statistic
500            ad_critical = ad_result.critical_values
501            ad_significance = ad_result.significance_level
502            ad_results.append((ad_stat, ad_significance))
503
504            # Add test results to histogram plot
505            textstr = "\n".join(
506                (
507                    f"KS test: p = {ks_pvalue:.4f}",
508                    f"AD test: p < {ad_significance:.3f}",
509                    f"AD stat: {ad_stat:.4f}",
510                )
511            )
512            props = dict(boxstyle="round", facecolor="wheat", alpha=0.8)
513            ax_hist.text(
514                0.05,
515                0.95,
516                textstr,
517                transform=ax_hist.transAxes,
518                fontsize=10,
519                verticalalignment="top",
520                bbox=props,
521            )
522
523            ax_hist.legend()
524            ax_hist.set_title(
525                f"Dimension {i+1} - Histograms with Statistical Tests"
526            )
527            ax_hist.set_xlabel("Value")
528            ax_hist.set_ylabel("Density")
529
530            # Bottom row: ECDFs with KS test visualization
531            ax_ecdf = axes[1, i]
532
533            # Compute ECDFs
534            sorted_orig = np.sort(Y_orig[:, i])
535            ecdf_orig = np.arange(1, len(sorted_orig) + 1) / len(sorted_orig)
536            sorted_sim = np.sort(Y_sim[:, i])
537            ecdf_sim = np.arange(1, len(sorted_sim) + 1) / len(sorted_sim)
538
539            # Plot ECDFs
540            ax_ecdf.step(
541                sorted_orig,
542                ecdf_orig,
543                label="Original",
544                color="blue",
545                linewidth=2,
546            )
547            ax_ecdf.step(
548                sorted_sim,
549                ecdf_sim,
550                label="Simulated",
551                color="red",
552                linewidth=2,
553            )
554
555            # Find the point of maximum difference for KS test
556            all_values = np.sort(np.concatenate([sorted_orig, sorted_sim]))
557            ecdf_orig_all = np.searchsorted(
558                sorted_orig, all_values, side="right"
559            ) / len(sorted_orig)
560            ecdf_sim_all = np.searchsorted(
561                sorted_sim, all_values, side="right"
562            ) / len(sorted_sim)
563            diff = np.abs(ecdf_orig_all - ecdf_sim_all)
564            max_idx = np.argmax(diff)
565            max_x = all_values[max_idx]
566            max_y1 = ecdf_orig_all[max_idx]
567            max_y2 = ecdf_sim_all[max_idx]
568
569            # Mark the maximum difference point
570            ax_ecdf.plot(
571                [max_x, max_x],
572                [max_y1, max_y2],
573                "k-",
574                linewidth=3,
575                label=f"KS stat: {ks_stat:.4f}",
576            )
577            ax_ecdf.plot(max_x, max_y1, "ko", markersize=8)
578            ax_ecdf.plot(max_x, max_y2, "ko", markersize=8)
579
580            ax_ecdf.legend()
581            ax_ecdf.set_title(f"Dimension {i+1} - ECDFs with KS Statistic")
582            ax_ecdf.set_xlabel("Value")
583            ax_ecdf.set_ylabel("ECDF")
584
585        plt.tight_layout()
586        if save_prefix:
587            plt.savefig(
588                f"{save_prefix}_statistical_comparison.png",
589                dpi=300,
590                bbox_inches="tight",
591            )
592        plt.show()
593
594        # Print comprehensive test results
595        print("\n" + "=" * 60)
596        print("COMPREHENSIVE STATISTICAL TEST RESULTS")
597        print("=" * 60)
598
599        for i in range(d):
600            ks_stat, ks_pvalue = ks_results[i]
601            ad_stat, ad_significance = ad_results[i]
602
603            print(f"\nDimension {i+1}:")
604            print(f"  Kolmogorov-Smirnov Test:")
605            print(f"    Statistic: {ks_stat:.6f}")
606            print(f"    p-value: {ks_pvalue:.6f}")
607            print(
608                f"    Significance: {'Not Significant' if ks_pvalue > 0.05 else 'SIGNIFICANT'}"
609            )
610
611            print(f"  Anderson-Darling Test:")
612            print(f"    Statistic: {ad_stat:.6f}")
613            print(f"    Significance level: {ad_significance:.3f}")
614            print(
615                f"    Interpretation: {'Distributions differ' if ad_stat > ad_result.critical_values[2] else 'Distributions similar'}"
616            )
617
618        # Q-Q plots for each dimension
619        fig, axes = plt.subplots(1, d, figsize=(5 * d, 5))
620        if d == 1:
621            axes = [axes]
622
623        for i in range(d):
624            orig_sorted = np.sort(Y_orig[:, i])
625            sim_sorted = np.sort(Y_sim[:, i])
626
627            n_orig = len(orig_sorted)
628            n_sim = len(sim_sorted)
629
630            n_points = min(n_orig, n_sim, 1000)
631            quantiles = np.linspace(0, 1, n_points)
632
633            orig_quantiles = np.quantile(orig_sorted, quantiles)
634            sim_quantiles = np.quantile(sim_sorted, quantiles)
635
636            axes[i].plot(
637                orig_quantiles, sim_quantiles, "o", alpha=0.6, markersize=3
638            )
639            min_val = min(orig_quantiles.min(), sim_quantiles.min())
640            max_val = max(orig_quantiles.max(), sim_quantiles.max())
641            axes[i].plot(
642                [min_val, max_val],
643                [min_val, max_val],
644                "r--",
645                alpha=0.8,
646                linewidth=2,
647            )
648            axes[i].set_xlabel("Original Data Quantiles")
649            axes[i].set_ylabel("Simulated Data Quantiles")
650            axes[i].set_title(f"Dimension {i+1} - Q-Q Plot")
651
652            corr = np.corrcoef(orig_quantiles, sim_quantiles)[0, 1]
653            axes[i].text(
654                0.05,
655                0.95,
656                f"Corr: {corr:.4f}",
657                transform=axes[i].transAxes,
658                bbox=dict(
659                    boxstyle="round,pad=0.3", facecolor="white", alpha=0.8
660                ),
661                verticalalignment="top",
662            )
663
664        plt.tight_layout()
665        if save_prefix:
666            plt.savefig(
667                f"{save_prefix}_qq_plots.png", dpi=300, bbox_inches="tight"
668            )
669        plt.show()
670
671        return {
672            "ks_results": ks_results,
673            "ad_results": ad_results,
674            "dimensions": d,
675        }
def fit(self, Y, metric='energy', n_trials=50, **kwargs):
296    def fit(self, Y, metric="energy", n_trials=50, **kwargs):
297        if Y.ndim == 1:
298            Y = Y.reshape(-1, 1)
299
300        n, d = Y.shape
301        self.n_features_ = d
302        self.n_samples_ = n
303        self.Y_ = Y  # Store once before optimization
304
305        self.X_dist = np.random.normal(0, 1, (n, d))
306
307        def objective(trial):
308            sigma = trial.suggest_float("sigma", 0.01, 10, log=True)
309            lambd = trial.suggest_float("lambd", 1e-5, 1, log=True)
310            lags = trial.suggest_int("lags", 1, 50)
311            n_hidden_features = trial.suggest_int("n_hidden_features", 1, 20)
312            gamma = 1 / (2 * sigma**2)
313
314            model = self._create_model(gamma, lambd, lags, n_hidden_features)
315            model.fit(Y)
316
317            # Generate synthetic sample using this model's residuals
318            Y_sim = self._generate_pseudo_with_model(
319                model, model.residuals_, n, random_state=trial.number
320            )
321
322            if metric == "energy":
323                dist_val = self._custom_energy_distance(Y, Y_sim)
324            elif metric == "mmd":
325                dist_val = self._mmd(Y, Y_sim)
326            elif metric == "wasserstein" and d == 1:
327                dist_val = stats.wasserstein_distance(
328                    Y.flatten(), Y_sim.flatten()
329                )
330            else:
331                raise ValueError("Invalid metric for dimension")
332
333            return dist_val
334
335        study = optuna.create_study(direction="minimize")
336        study.optimize(objective, n_trials=n_trials, **kwargs)
337
338        self.best_params_ = study.best_params
339        self.best_score_ = study.best_value
340        sigma = self.best_params_["sigma"]
341        lambd = self.best_params_["lambd"]
342        lags = self.best_params_["lags"]
343        n_hidden_features = self.best_params_["n_hidden_features"]
344        gamma = 1 / (2 * sigma**2)
345
346        self.model = self._create_model(gamma, lambd, lags, n_hidden_features)
347        self.model.fit(Y)
348
349        self.residuals_ = self.model.residuals_
350
351        self._fit_residual_sampler()
352        self.is_fitted = True
353
354        print(f"  Best energy distance: {self.best_score_:.6f}")
355        print(f"  Best lags: {lags}, n_hidden_features: {n_hidden_features}")
356
357        return self
class DiversityGenerator:
  7class DiversityGenerator:
  8    """
  9    Three-step Gaussian Copula transformation for controlled diversity generation
 10    while preserving marginal distributions.
 11    """
 12
 13    def __init__(
 14        self, target_correlation=0.1, preserve_moments=True, random_state=None
 15    ):
 16        self.target_correlation = target_correlation
 17        self.preserve_moments = preserve_moments
 18        self.random_state = random_state
 19        if random_state is not None:
 20            np.random.seed(random_state)
 21        self.fitted_ = False  # Initialize fitted_ attribute
 22
 23    def fit(self, X):
 24        """
 25        STEP 1: Learn ECDFs and create target correlation matrix
 26        """
 27        X = np.asarray(X)
 28        self.n_samples_, self.n_features_ = X.shape
 29        self.original_dtype_ = X.dtype
 30
 31        # Store original statistics for moment preservation
 32        self.original_means_ = np.mean(X, axis=0)
 33        self.original_stds_ = np.std(X, axis=0)
 34
 35        # Store ECDF information for inverse transformation
 36        self.sorted_columns_ = [
 37            np.sort(X[:, j]) for j in range(self.n_features_)
 38        ]
 39        self.quantile_positions_ = (np.arange(1, self.n_samples_ + 1)) / (
 40            self.n_samples_ + 1
 41        )
 42
 43        # Create target correlation matrix
 44        self.target_corr_matrix_ = self._create_target_correlation_matrix()
 45
 46        # Precompute Cholesky decomposition for correlation application
 47        try:
 48            self.cholesky_factor_ = np.linalg.cholesky(self.target_corr_matrix_)
 49        except np.linalg.LinAlgError:
 50            self.target_corr_matrix_ = self._nearest_positive_definite(
 51                self.target_corr_matrix_
 52            )
 53            self.cholesky_factor_ = np.linalg.cholesky(self.target_corr_matrix_)
 54
 55        self.fitted_ = True
 56        return self
 57
 58    def _create_target_correlation_matrix(self):
 59        """Create valid target correlation matrix"""
 60        if isinstance(self.target_correlation, (int, float)):
 61            corr_val = float(self.target_correlation)
 62            corr_val = np.clip(corr_val, -1.0 / (self.n_features_ - 1), 1.0)
 63
 64            corr_matrix = np.full(
 65                (self.n_features_, self.n_features_), corr_val
 66            )
 67            np.fill_diagonal(corr_matrix, 1.0)
 68
 69        elif isinstance(self.target_correlation, np.ndarray):
 70            corr_matrix = self.target_correlation.copy()
 71            np.fill_diagonal(corr_matrix, 1.0)
 72        else:
 73            raise ValueError("target_correlation must be scalar or matrix")
 74
 75        return corr_matrix
 76
 77    def _nearest_positive_definite(self, matrix):
 78        """Ensure matrix is positive definite"""
 79        n = matrix.shape[0]
 80        matrix = (matrix + matrix.T) / 2
 81
 82        min_eigval = np.min(np.linalg.eigvals(matrix))
 83        if min_eigval > 0:
 84            return matrix
 85
 86        identity = np.eye(n)
 87        for k in range(1, 1000):
 88            candidate = matrix + k * 1e-8 * identity
 89            if np.min(np.linalg.eigvals(candidate)) > 0:
 90                return candidate
 91
 92        return np.eye(n)
 93
 94    def transform_to_gaussian(self, X):
 95        """
 96        STEP 2: Transform X → ranks → uniform → Gaussian
 97        """
 98        X = np.asarray(X)
 99        n_new = X.shape[0]
100        Y = np.zeros((n_new, self.n_features_), dtype=float)
101
102        for j in range(self.n_features_):
103            sorted_vals = self.sorted_columns_[j]
104
105            # X → ranks → uniform
106            empirical_cdf = np.searchsorted(
107                sorted_vals, X[:, j], side="right"
108            ) / (self.n_samples_ + 1)
109            empirical_cdf = np.clip(empirical_cdf, 0.001, 0.999)
110
111            # uniform → Gaussian (probit transform)
112            Y[:, j] = norm.ppf(empirical_cdf)
113
114        return Y
115
116    def apply_target_correlation(self, Y):
117        """
118        STEP 2 (continued): Apply target correlation to Gaussian data
119        """
120        return Y @ self.cholesky_factor_.T
121
122    def transform_from_gaussian(self, Y_transformed):
123        """
124        STEP 3: Transform Gaussian → uniform → inverse ECDF → X_diverse
125        """
126        n_new = Y_transformed.shape[0]
127        Z = np.zeros((n_new, self.n_features_), dtype=float)
128
129        for j in range(self.n_features_):
130            sorted_vals = self.sorted_columns_[j]
131
132            # Gaussian → uniform
133            U_transformed = norm.cdf(Y_transformed[:, j])
134
135            # uniform → inverse ECDF → X_diverse
136            Z[:, j] = np.interp(
137                U_transformed, self.quantile_positions_, sorted_vals
138            )
139
140            # Optional moment preservation
141            if self.preserve_moments:
142                Z[:, j] = self._preserve_moments(Z[:, j], j)
143
144        return Z.astype(self.original_dtype_)
145
146    def _preserve_moments(self, values, feature_idx):
147        """Preserve mean and standard deviation if needed"""
148        current_mean = np.mean(values)
149        current_std = np.std(values)
150
151        target_mean = self.original_means_[feature_idx]
152        target_std = self.original_stds_[feature_idx]
153
154        mean_ratio = abs(current_mean - target_mean) / (
155            abs(target_mean) + 1e-10
156        )
157        std_ratio = abs(current_std - target_std) / (target_std + 1e-10)
158
159        if mean_ratio > 0.02 or std_ratio > 0.05:
160            values_centered = values - current_mean
161            if current_std > 1e-10:
162                values_scaled = values_centered * (target_std / current_std)
163            else:
164                values_scaled = values_centered
165            return values_scaled + target_mean
166
167        return values
168
169    def generate_diverse_samples(self, X, n_samples=5):
170        """Generate diverse samples using the three-step pipeline"""
171        if not self.fitted_:
172            self.fit(X)
173
174        diverse_samples = []
175
176        for i in range(n_samples):
177            if i == 0:
178                # First sample: transform original data
179                Y_gaussian = self.transform_to_gaussian(X)
180            else:
181                # Additional samples: generate new Gaussian data
182                Y_gaussian = np.random.normal(
183                    0, 1, (self.n_samples_, self.n_features_)
184                )
185
186            # Apply target correlation
187            Y_diverse = self.apply_target_correlation(Y_gaussian)
188
189            # Transform back to original distributions
190            X_diverse = self.transform_from_gaussian(Y_diverse)
191            diverse_samples.append(X_diverse)
192
193        return np.array(diverse_samples)
194
195    def fit_transform(self, X, n_samples=5):
196        """Fit and generate diverse samples in one call"""
197        return self.generate_diverse_samples(X, n_samples)

Three-step Gaussian Copula transformation for controlled diversity generation while preserving marginal distributions.

def fit(self, X):
23    def fit(self, X):
24        """
25        STEP 1: Learn ECDFs and create target correlation matrix
26        """
27        X = np.asarray(X)
28        self.n_samples_, self.n_features_ = X.shape
29        self.original_dtype_ = X.dtype
30
31        # Store original statistics for moment preservation
32        self.original_means_ = np.mean(X, axis=0)
33        self.original_stds_ = np.std(X, axis=0)
34
35        # Store ECDF information for inverse transformation
36        self.sorted_columns_ = [
37            np.sort(X[:, j]) for j in range(self.n_features_)
38        ]
39        self.quantile_positions_ = (np.arange(1, self.n_samples_ + 1)) / (
40            self.n_samples_ + 1
41        )
42
43        # Create target correlation matrix
44        self.target_corr_matrix_ = self._create_target_correlation_matrix()
45
46        # Precompute Cholesky decomposition for correlation application
47        try:
48            self.cholesky_factor_ = np.linalg.cholesky(self.target_corr_matrix_)
49        except np.linalg.LinAlgError:
50            self.target_corr_matrix_ = self._nearest_positive_definite(
51                self.target_corr_matrix_
52            )
53            self.cholesky_factor_ = np.linalg.cholesky(self.target_corr_matrix_)
54
55        self.fitted_ = True
56        return self

STEP 1: Learn ECDFs and create target correlation matrix