Example usage

library(misc)
## Loading required package: MASS
## Loading required package: foreach
## Loading required package: doSNOW
## Loading required package: iterators
## Loading required package: snow
## 
## Attaching package: 'misc'
## The following object is masked from 'package:stats':
## 
##     integrate
# Build a reproducible train/test split of mtcars: 20 randomly drawn rows are
# used for fitting and the remaining rows are held out.
data(mtcars)
set.seed(42)
train_idx <- sample(nrow(mtcars), 20)

# Predictors are wt, hp and qsec; the response is mpg. Rows are selected
# first, then the predictor columns, which yields the same data frames as a
# single two-index subset.
X_train <- mtcars[train_idx, ][c("wt", "hp", "qsec")]
y_train <- mtcars[["mpg"]][train_idx]
X_test <- mtcars[-train_idx, ][c("wt", "hp", "qsec")]
y_test <- mtcars[["mpg"]][-train_idx]

# Hand-checkable case: decompose lm predictions for the first two test rows
# with the response left uncentered, then reproduce the numbers manually.

# Print `label` followed by `values` rounded to four decimals on one line.
show_rounded <- function(label, values) {
  cat(label, round(values, 4), "\n")
}

simple_X <- head(X_test, 2)
result_debug <- fit_predict_shap(
  X_train, y_train, simple_X,
  model_type = "lm",
  center_response = FALSE
)

cat("Simple example (no response centering):\n")
## Simple example (no response centering):
show_rounded("Coefficients:", coef(result_debug$model))
## Coefficients: 25.723 -3.7499 -0.0185 0.5064
show_rounded("Feature means:", result_debug$feature_means)
## Feature means: 3.4412 157.05 17.8145
show_rounded("Expected prediction at means:", result_debug$base_value)
## Expected prediction at means: 18.94
cat("Test data:\n")
## Test data:
print(simple_X)
##                  wt  hp  qsec
## Mazda RX4 Wag 2.875 110 17.02
## Valiant       3.460 105 20.22
show_rounded("Predictions:", result_debug$predictions)
## Predictions: 21.5299 21.0492
show_rounded("Base value:", result_debug$base_value)
## Base value: 18.94
cat("SHAP values:\n")
## SHAP values:
print(round(result_debug$shap_values, 4))
##                    wt     hp    qsec
## Mazda RX4 Wag  2.1232 0.8691 -0.4024
## Valiant       -0.0705 0.9614  1.2183
# Additivity check: base value plus the per-row SHAP total is printed next to
# the predictions above so the two can be compared.
shap_total <- result_debug$base_value + rowSums(result_debug$shap_values)
show_rounded("Base + SHAP sum:", shap_total)
## Base + SHAP sum: 21.5299 21.0492
show_rounded("Residuals:", result_debug$validation$residuals)
## Residuals: 0 0
# Manual calculation: rebuild the predictions and the base value directly from
# the fitted coefficients (intercept + X %*% slopes, evaluated at the test rows
# and at the feature means respectively).
coefs <- coef(result_debug$model)
intercept <- coefs["(Intercept)"]
is_slope <- names(coefs) != "(Intercept)"
slopes <- coefs[is_slope]
linear_part <- as.matrix(simple_X) %*% slopes
manual_pred <- intercept + linear_part
manual_base <- intercept + sum(slopes * result_debug$feature_means)
show_rounded("Manual prediction:", as.numeric(manual_pred))
## Manual prediction: 21.5299 21.0492
show_rounded("Manual base value:", as.numeric(manual_base))
## Manual base value: 18.94
# =============================================================================
# ORIGINAL EXAMPLE USAGE AND TESTING
# =============================================================================

cat("=== Testing Different Model Types ===\n\n")
## === Testing Different Model Types ===
# Test 1: ordinary least squares fit on the full test set, with the response
# centered before fitting.
cat("1. Linear Model (lm):\n")
## 1. Linear Model (lm):
result_lm <- fit_predict_shap(
  X_train, y_train, X_test,
  model_type = "lm",
  center_response = TRUE
)
print_shap_summary(result_lm)
## === SHAP Analysis Summary ===
## Model type: lm 
## Response centered: TRUE 
## Correlation-aware: FALSE 
## Base value: 18.94 
## Validation passed: TRUE 
## Max residual: 3.552714e-15 
## 
## Feature means:
##       wt       hp     qsec 
##   3.4412 157.0500  17.8145 
## 
## SHAP values (first few observations):
##                    wt      hp    qsec
## Mazda RX4 Wag  2.1232  0.8691 -0.4024
## Valiant       -0.0705  0.9614  1.2183
## Merc 240D      0.9420  1.7557  1.1068
## Merc 450SE    -2.3579 -0.4239 -0.2099
## Merc 450SL    -1.0830 -0.4239 -0.1086
## Honda Civic    6.8480  1.9404  0.3573

# Print `label` followed by the first few `values`, rounded to four decimals.
show_head <- function(label, values) {
  cat(label, round(head(values), 4), "\n")
}

# Debug: compare three routes to the same numbers for the lm fit.
cat("\nManual verification for lm:\n")
## 
## Manual verification for lm:
# Raw predictions from the stored model; response_mean is added back before
# printing (presumably because the model was fit on the centered response).
manual_pred <- predict(result_lm$model, newdata = X_test)
show_head("Model predictions:", manual_pred + result_lm$response_mean)
## Model predictions: 21.5299 21.0492 22.7445 15.9482 17.3245 28.0858
show_head("Our predictions:", result_lm$predictions)
## Our predictions: 21.5299 21.0492 22.7445 15.9482 17.3245 28.0858
show_head("Base + sum(SHAP):", result_lm$base_value + rowSums(result_lm$shap_values))
## Base + sum(SHAP): 21.5299 21.0492 22.7445 15.9482 17.3245 28.0858
cat("\n")
# Test 2: robust regression (rlm) with the correlation-aware option switched
# on and n_samples = 50. center_response is not passed; the summary below
# reports it as TRUE.
# NOTE(review): the recorded summary shows `Validation passed: FALSE` with a
# max residual of about 0.79, and the SHAP matrix prints without row names,
# unlike the other model types. n_samples suggests a sampling-based estimate,
# so results presumably depend on the RNG state -- confirm this is expected.
cat("2. Robust Linear Model (rlm) with correlations:\n")
## 2. Robust Linear Model (rlm) with correlations:
result_rlm <- fit_predict_shap(
  X_train, y_train, X_test,
  model_type = "rlm",
  account_correlations = TRUE,
  n_samples = 50
)
print_shap_summary(result_rlm)
## === SHAP Analysis Summary ===
## Model type: rlm 
## Response centered: TRUE 
## Correlation-aware: TRUE 
## Base value: 18.5975 
## Validation passed: FALSE 
## Max residual: 7.903986e-01 
## 
## Feature means:
##       wt       hp     qsec 
##   3.4412 157.0500  17.8145 
## 
## SHAP values (first few observations):
##           wt      hp    qsec
## [1,]  2.0299  0.8603 -0.4013
## [2,] -0.1211  0.6161  0.9803
## [3,]  1.1926  1.4873  0.9902
## [4,] -2.4383 -0.2452 -0.2433
## [5,] -0.6410 -0.1960  0.0512
## [6,]  7.4959  1.7437  0.4040
cat("\n")
# Test 3: quantile regression (rq) at tau = 0.5, using the default
# (non-correlation-aware) SHAP computation as reported in the summary below.
cat("3. Quantile Regression (rq):\n")
## 3. Quantile Regression (rq):
result_rq <- fit_predict_shap(
  X_train, y_train, X_test,
  model_type = "rq",
  tau = 0.5
)
print_shap_summary(result_rq)
## === SHAP Analysis Summary ===
## Model type: rq 
## Response centered: TRUE 
## Correlation-aware: FALSE 
## Base value: 18.6546 
## Validation passed: TRUE 
## Max residual: 7.105427e-15 
## 
## Feature means:
##       wt       hp     qsec 
##   3.4412 157.0500  17.8145 
## 
## SHAP values (first few observations):
##                    wt      hp    qsec
## Mazda RX4 Wag  2.2088  0.4239 -0.3828
## Valiant       -0.0733  0.4690  1.1590
## Merc 240D      0.9800  0.8564  1.0530
## Merc 450SE    -2.4530 -0.2068 -0.1997
## Merc 450SL    -1.1266 -0.2068 -0.1034
## Honda Civic    7.1242  0.9465  0.3399
cat("\n")
# Test 4: elastic net via glmnet with a fixed penalty (lambda = 0.1) and an
# even L1/L2 mix (alpha = 0.5).
cat("4. Elastic Net (glmnet):\n")
## 4. Elastic Net (glmnet):
result_glmnet <- fit_predict_shap(
  X_train, y_train, X_test,
  model_type = "glmnet",
  lambda = 0.1,
  alpha = 0.5
)
print_shap_summary(result_glmnet)
## === SHAP Analysis Summary ===
## Model type: glmnet 
## Response centered: TRUE 
## Correlation-aware: FALSE 
## Base value: 18.94 
## Validation passed: TRUE 
## Max residual: 3.552714e-15 
## 
## Feature means:
##       wt       hp     qsec 
##   3.4412 157.0500  17.8145 
## 
## SHAP values (first few observations):
##                    wt      hp    qsec
## Mazda RX4 Wag  2.0584  0.9155 -0.3651
## Valiant       -0.0683  1.0128  1.1055
## Merc 240D      0.9132  1.8494  1.0044
## Merc 450SE    -2.2860 -0.4466 -0.1905
## Merc 450SL    -1.0499 -0.4466 -0.0986
## Honda Civic    6.6392  2.0440  0.3242
cat("\n")
# Side-by-side comparison on the first three test rows: the same lm fit is
# explained once with account_correlations = FALSE and once with it TRUE
# (n_samples = 100), and the element-wise difference is printed.
cat("=== Comparison: Interventional vs Correlation-Aware SHAP ===\n")
## === Comparison: Interventional vs Correlation-Aware SHAP ===
comparison_rows <- X_test[1:3, ]

# The interventional fit runs first, exactly as before, so the order of the
# two calls (and any random number use inside them) is unchanged.
result_interventional <- fit_predict_shap(
  X_train, y_train, comparison_rows,
  model_type = "lm",
  account_correlations = FALSE
)

result_corr_aware <- fit_predict_shap(
  X_train, y_train, comparison_rows,
  model_type = "lm",
  account_correlations = TRUE,
  n_samples = 100
)

cat("\nInterventional SHAP (first 3 observations):\n")
## 
## Interventional SHAP (first 3 observations):
print(round(result_interventional$shap_values, 4))
##                    wt     hp    qsec
## Mazda RX4 Wag  2.1232 0.8691 -0.4024
## Valiant       -0.0705 0.9614  1.2183
## Merc 240D      0.9420 1.7557  1.1068
cat("\nCorrelation-aware SHAP (first 3 observations):\n")
## 
## Correlation-aware SHAP (first 3 observations):
print(round(result_corr_aware$shap_values, 4))
##           wt     hp    qsec
## [1,]  2.2556 0.8688 -0.3642
## [2,] -0.2411 0.9089  1.1454
## [3,]  0.9782 1.6341  1.1042
cat("\nDifference (Correlation-aware - Interventional):\n")
## 
## Difference (Correlation-aware - Interventional):
# NOTE(review): the correlation-aware matrix prints without row names in the
# recorded output, and so does the difference below.
shap_delta <- result_corr_aware$shap_values - result_interventional$shap_values
print(round(shap_delta, 4))
##           wt      hp    qsec
## [1,]  0.1324 -0.0003  0.0382
## [2,] -0.1706 -0.0526 -0.0728
## [3,]  0.0363 -0.1217 -0.0026