Example usage
## Loading required package: MASS
## Loading required package: foreach
## Loading required package: doSNOW
## Loading required package: iterators
## Loading required package: snow
##
## Attaching package: 'misc'
## The following object is masked from 'package:stats':
##
## integrate
# Prepare test data: draw a reproducible 20-row training sample from mtcars and
# hold out the remaining rows for testing. Predictors are wt, hp and qsec; the
# response is mpg.
data(mtcars)
set.seed(42)
train_idx <- sample.int(nrow(mtcars), size = 20)
X_train <- mtcars[c("wt", "hp", "qsec")][train_idx, ]
y_train <- mtcars$mpg[train_idx]
X_test <- mtcars[c("wt", "hp", "qsec")][-train_idx, ]
y_test <- mtcars$mpg[-train_idx]
# Let's manually verify with a simple example: explain only the first two
# held-out rows with an lm fitted without response centering, then print every
# component of the result so the decomposition can be checked by eye.
simple_X <- X_test[1:2, ]
result_debug <- fit_predict_shap(X_train, y_train, simple_X,
                                 model_type = "lm", center_response = FALSE)
cat("Simple example (no response centering):\n")
## Simple example (no response centering):
cat("Coefficients:", round(coef(result_debug$model), 4), "\n")
## Coefficients: 25.723 -3.7499 -0.0185 0.5064
cat("Feature means:", round(result_debug$feature_means, 4), "\n")
## Feature means: 3.4412 157.05 17.8145
cat("Expected prediction at means:", round(result_debug$base_value, 4), "\n")
## Expected prediction at means: 18.94
# The two rows being explained (this statement was missing from the listing
# although its output was shown).
cat("Test data:\n")
print(simple_X)
## Test data:
## wt hp qsec
## Mazda RX4 Wag 2.875 110 17.02
## Valiant 3.460 105 20.22
cat("Predictions:", round(result_debug$predictions, 4), "\n")
## Predictions: 21.5299 21.0492
cat("Base value:", round(result_debug$base_value, 4), "\n")
## Base value: 18.94
# Per-feature contributions (statement restored; only its output was present).
cat("SHAP values:\n")
print(round(result_debug$shap_values, 4))
## SHAP values:
## wt hp qsec
## Mazda RX4 Wag 2.1232 0.8691 -0.4024
## Valiant -0.0705 0.9614 1.2183
# Additivity check: base value plus the row-wise SHAP sum should reproduce the
# predictions printed above.
cat("Base + SHAP sum:", round(result_debug$base_value + rowSums(result_debug$shap_values), 4), "\n")
## Base + SHAP sum: 21.5299 21.0492
cat("Residuals:", round(result_debug$validation$residuals, 4), "\n")
## Residuals: 0 0
# Manual calculation: rebuild the predictions and the base value straight from
# the fitted coefficients, independently of the SHAP decomposition.
coefs <- coef(result_debug$model)
intercept <- coefs["(Intercept)"]
slopes <- coefs[names(coefs) != "(Intercept)"]
# Prediction = intercept + X %*% slopes for the rows being explained.
manual_pred <- intercept + as.matrix(simple_X) %*% slopes
# Base value = the same linear predictor evaluated at the stored feature means.
manual_base <- intercept + sum(slopes * result_debug$feature_means)
cat("Manual prediction:", round(as.numeric(manual_pred), 4), "\n")
## Manual prediction: 21.5299 21.0492
# `manual_base` was computed but never printed, although its output line was
# shown; print it so it can be compared with the reported base value.
cat("Manual base value:", round(as.numeric(manual_base), 4), "\n")
## Manual base value: 18.94
# =============================================================================
# ORIGINAL EXAMPLE USAGE AND TESTING
# =============================================================================
cat("=== Testing Different Model Types ===\n\n")
## === Testing Different Model Types ===
# Test 1: Standard linear model
cat("1. Linear Model (lm):\n")
## 1. Linear Model (lm):
# NOTE(review): `result_lm` is read further down but is not assigned anywhere
# in the visible script. The call below is reconstructed from the printed
# summary (model type lm, response centered, not correlation-aware, all held-out
# rows) and only runs when `result_lm` does not already exist -- confirm against
# the original source. The statement that printed the summary itself is also
# absent from this listing; its helper name is unknown, so it is not recreated.
if (!exists("result_lm")) {
  result_lm <- fit_predict_shap(X_train, y_train, X_test,
                                model_type = "lm", center_response = TRUE)
}
## === SHAP Analysis Summary ===
## Model type: lm
## Response centered: TRUE
## Correlation-aware: FALSE
## Base value: 18.94
## Validation passed: TRUE
## Max residual: 3.552714e-15
##
## Feature means:
## wt hp qsec
## 3.4412 157.0500 17.8145
##
## SHAP values (first few observations):
## wt hp qsec
## Mazda RX4 Wag 2.1232 0.8691 -0.4024
## Valiant -0.0705 0.9614 1.2183
## Merc 240D 0.9420 1.7557 1.1068
## Merc 450SE -2.3579 -0.4239 -0.2099
## Merc 450SL -1.0830 -0.4239 -0.1086
## Honda Civic 6.8480 1.9404 0.3573
# Debug: Let's manually check the decomposition for lm
cat("\nManual verification for lm:\n")
##
## Manual verification for lm:
# The stored response mean is added back to the raw model predictions before
# they are compared with the reported predictions and with base + sum(SHAP).
manual_pred <- predict(result_lm$model, newdata = X_test)
cat("Model predictions:", round(head(manual_pred + result_lm$response_mean), 4), "\n")
## Model predictions: 21.5299 21.0492 22.7445 15.9482 17.3245 28.0858
cat("Our predictions:", round(head(result_lm$predictions), 4), "\n")
## Our predictions: 21.5299 21.0492 22.7445 15.9482 17.3245 28.0858
cat("Base + sum(SHAP):", round(head(result_lm$base_value + rowSums(result_lm$shap_values)), 4), "\n")
## Base + sum(SHAP): 21.5299 21.0492 22.7445 15.9482 17.3245 28.0858
# Test 2: Robust linear model with correlations
cat("2. Robust Linear Model (rlm) with correlations:\n")
## 2. Robust Linear Model (rlm) with correlations:
# NOTE(review): the statement that fits this model and prints the summary below
# is not present in this listing; only its output survives.
# NOTE(review): this summary reports `Validation passed: FALSE` with a max
# residual of about 0.79, and the SHAP rows are labelled [1,], [2,], ... instead
# of car names as in the other summaries. Presumably both come from the
# correlation-aware code path -- verify inside fit_predict_shap.
## === SHAP Analysis Summary ===
## Model type: rlm
## Response centered: TRUE
## Correlation-aware: TRUE
## Base value: 18.5975
## Validation passed: FALSE
## Max residual: 7.903986e-01
##
## Feature means:
## wt hp qsec
## 3.4412 157.0500 17.8145
##
## SHAP values (first few observations):
## wt hp qsec
## [1,] 2.0299 0.8603 -0.4013
## [2,] -0.1211 0.6161 0.9803
## [3,] 1.1926 1.4873 0.9902
## [4,] -2.4383 -0.2452 -0.2433
## [5,] -0.6410 -0.1960 0.0512
## [6,] 7.4959 1.7437 0.4040
# Test 3: Quantile regression
cat("3. Quantile Regression (rq):\n")
## 3. Quantile Regression (rq):
# NOTE(review): the statement that fits this model and prints the summary below
# is not present in this listing; only its output survives. The quantile used
# is not shown anywhere in the output -- confirm against the original source.
## === SHAP Analysis Summary ===
## Model type: rq
## Response centered: TRUE
## Correlation-aware: FALSE
## Base value: 18.6546
## Validation passed: TRUE
## Max residual: 7.105427e-15
##
## Feature means:
## wt hp qsec
## 3.4412 157.0500 17.8145
##
## SHAP values (first few observations):
## wt hp qsec
## Mazda RX4 Wag 2.2088 0.4239 -0.3828
## Valiant -0.0733 0.4690 1.1590
## Merc 240D 0.9800 0.8564 1.0530
## Merc 450SE -2.4530 -0.2068 -0.1997
## Merc 450SL -1.1266 -0.2068 -0.1034
## Honda Civic 7.1242 0.9465 0.3399
# Test 4: Elastic net (glmnet)
cat("4. Elastic Net (glmnet):\n")
## 4. Elastic Net (glmnet):
# NOTE(review): the statement that fits this model and prints the summary below
# is not present in this listing; only its output survives. The penalty
# settings (e.g. alpha / lambda) are not visible in the output -- confirm
# against the original source.
## === SHAP Analysis Summary ===
## Model type: glmnet
## Response centered: TRUE
## Correlation-aware: FALSE
## Base value: 18.94
## Validation passed: TRUE
## Max residual: 3.552714e-15
##
## Feature means:
## wt hp qsec
## 3.4412 157.0500 17.8145
##
## SHAP values (first few observations):
## wt hp qsec
## Mazda RX4 Wag 2.0584 0.9155 -0.3651
## Valiant -0.0683 1.0128 1.1055
## Merc 240D 0.9132 1.8494 1.0044
## Merc 450SE -2.2860 -0.4466 -0.1905
## Merc 450SL -1.0499 -0.4466 -0.0986
## Honda Civic 6.6392 2.0440 0.3242
# Comparison of interventional vs correlation-aware SHAP: explain the same
# three held-out rows with the same lm, toggling only `account_correlations`,
# then print both attribution matrices and their element-wise difference.
cat("=== Comparison: Interventional vs Correlation-Aware SHAP ===\n")
## === Comparison: Interventional vs Correlation-Aware SHAP ===
result_interventional <- fit_predict_shap(X_train, y_train, X_test[1:3, ],
                                          model_type = "lm",
                                          account_correlations = FALSE)
# NOTE(review): `n_samples` suggests a sampling-based estimate, so these values
# may vary between runs unless the RNG state is fixed -- confirm in
# fit_predict_shap.
result_corr_aware <- fit_predict_shap(X_train, y_train, X_test[1:3, ],
                                      model_type = "lm",
                                      account_correlations = TRUE,
                                      n_samples = 100)
cat("\nInterventional SHAP (first 3 observations):\n")
##
## Interventional SHAP (first 3 observations):
print(round(result_interventional$shap_values, 4))
## wt hp qsec
## Mazda RX4 Wag 2.1232 0.8691 -0.4024
## Valiant -0.0705 0.9614 1.2183
## Merc 240D 0.9420 1.7557 1.1068
cat("\nCorrelation-aware SHAP (first 3 observations):\n")
##
## Correlation-aware SHAP (first 3 observations):
# Print statement restored: the matrix below was shown without the call that
# produces it, unlike the interventional case above.
print(round(result_corr_aware$shap_values, 4))
## wt hp qsec
## [1,] 2.2556 0.8688 -0.3642
## [2,] -0.2411 0.9089 1.1454
## [3,] 0.9782 1.6341 1.1042
cat("\nDifference (Correlation-aware - Interventional):\n")
##
## Difference (Correlation-aware - Interventional):
# The correlation-aware matrix is printed without row names, so the difference
# is labelled [1,], [2,], ... as well (see output below).
print(round(result_corr_aware$shap_values - result_interventional$shap_values, 4))
## wt hp qsec
## [1,] 0.1324 -0.0003 0.0382
## [2,] -0.1706 -0.0526 -0.0728
## [3,] 0.0363 -0.1217 -0.0026