## Loading required package: Matrix
## Loaded glmnet 4.1-10
set.seed(123)
library(glmnet)
data(Boston, package = "MASS")
# -------------------------
# Data: 80/20 train/test split of the Boston housing data
# -------------------------
# The response is `medv`; every other column is a predictor. Selecting the
# predictors by name (instead of the positional index 14) keeps this correct
# even if the column order of the data set changes.
response <- "medv"
X <- as.matrix(Boston[, setdiff(names(Boston), response)])
y <- Boston[[response]]
n <- nrow(X)
# seq_len(n) is the idiomatic safe sequence (1:n misbehaves when n == 0);
# under the same seed it draws exactly the same indices as sample(1:n, ...).
idx <- sample(seq_len(n), size = round(0.8 * n))
# drop = FALSE guarantees the subsets stay matrices whatever their row count.
X_train <- X[idx, , drop = FALSE]
y_train <- y[idx]
X_test <- X[-idx, , drop = FALSE]
y_test <- y[-idx]
# -------------------------
# Grid
# -------------------------
# Full factorial design over the tuning axes listed below; as with any
# expand.grid() call, the first axis (n_hidden) varies fastest.
grid <- local({
  axes <- list(
    n_hidden = c(175, 200, 225, 250),
    alpha = seq(from = 0.1, to = 0.5, by = 0.2),
    include_original = c(TRUE, FALSE),
    seed = 1
  )
  do.call(expand.grid, c(axes, stringsAsFactors = FALSE))
})
# One empty slot per grid row; each slot is filled by the tuning loop.
results <- vector(mode = "list", length = nrow(grid))
# -------------------------
# Loop
# -------------------------
# For each grid row:
#   1. fit an rvflnet model on the training split;
#   2. score every lambda on the fitted path against the test split;
#   3. keep the lambda with the lowest RMSE, count its non-zero coefficients;
#   4. store one summary row in results[[i]].
#
# NOTE(review): lambda is selected by minimising RMSE on the *test* set, so the
# stored `rmse` is an optimistic (selection-biased) estimate rather than a true
# held-out error. For an unbiased figure, choose lambda on a validation split
# (or by cross-validation) and report test RMSE only at that chosen lambda.
for (i in seq_len(nrow(grid))) {
  params <- grid[i, ]
  cat("\n========================================\n")
  cat(sprintf("Run %d / %d\n", i, nrow(grid)))
  print(params)
  # -------------------------
  # Fit model
  # -------------------------
  fit <- rvflnet(
    X_train, y_train,
    n_hidden = params$n_hidden,
    activation = "sigmoid",
    W_type = "gaussian",
    seed = params$seed,
    include_original = params$include_original,
    alpha = params$alpha
  )
  # -------------------------
  # Evaluate full lambda path
  # -------------------------
  # Presumably fit$fit is the underlying glmnet object whose $lambda is the
  # regularisation path -- TODO confirm against rvflnet.
  lambdas <- fit$fit$lambda
  # as.matrix() keeps colMeans() valid in case predict() returns a plain
  # vector (e.g. for a single lambda); it is a no-op on a matrix.
  preds <- as.matrix(predict(fit, newx = X_test, s = lambdas))
  # One column per lambda; y_test recycles down each column of preds.
  rmse_path <- sqrt(colMeans((preds - y_test)^2))
  best_idx <- which.min(rmse_path)
  # unname(): rmse_path carries the prediction column names (one per lambda).
  # Left in place, that name becomes the row name of the stored data.frame, and
  # equal lambdas across runs then get mangled by make.unique() when rbind-ed.
  best_rmse <- unname(rmse_path[best_idx])
  best_lambda <- unname(lambdas[best_idx])
  # -------------------------
  # Sparsity
  # -------------------------
  coef_mat <- coef(fit, s = best_lambda)
  # Row 1 is assumed to be the intercept (glmnet convention) and is excluded.
  nonzero <- sum(coef_mat[-1, 1] != 0)
  # -------------------------
  # Verbose output
  # -------------------------
  cat(sprintf("Best RMSE: %.4f\n", best_rmse))
  cat(sprintf("Best lambda: %.6f\n", best_lambda))
  cat(sprintf("Non-zero coeffs: %d\n", nonzero))
  # -------------------------
  # Store
  # -------------------------
  results[[i]] <- data.frame(
    n_hidden = params$n_hidden,
    alpha = params$alpha,
    include_original = params$include_original,
    seed = params$seed,
    rmse = best_rmse,
    lambda = best_lambda,
    nonzero = nonzero
  )
}
##
## ========================================
## Run 1 / 24
## n_hidden alpha include_original seed
## 1 175 0.1 TRUE 1
## Best RMSE: 2.9200
## Best lambda: 0.036435
## Non-zero coeffs: 165
##
## ========================================
## Run 2 / 24
## n_hidden alpha include_original seed
## 2 200 0.1 TRUE 1
## Best RMSE: 2.8819
## Best lambda: 0.027562
## Non-zero coeffs: 190
##
## ========================================
## Run 3 / 24
## n_hidden alpha include_original seed
## 3 225 0.1 TRUE 1
## Best RMSE: 3.0342
## Best lambda: 0.063671
## Non-zero coeffs: 183
##
## ========================================
## Run 4 / 24
## n_hidden alpha include_original seed
## 4 250 0.1 TRUE 1
## Best RMSE: 3.0472
## Best lambda: 0.052861
## Non-zero coeffs: 195
##
## ========================================
## Run 5 / 24
## n_hidden alpha include_original seed
## 5 175 0.3 TRUE 1
## Best RMSE: 2.9413
## Best lambda: 0.023293
## Non-zero coeffs: 144
##
## ========================================
## Run 6 / 24
## n_hidden alpha include_original seed
## 6 200 0.3 TRUE 1
## Best RMSE: 2.8847
## Best lambda: 0.017620
## Non-zero coeffs: 167
##
## ========================================
## Run 7 / 24
## n_hidden alpha include_original seed
## 7 225 0.3 TRUE 1
## Best RMSE: 3.0079
## Best lambda: 0.037089
## Non-zero coeffs: 149
##
## ========================================
## Run 8 / 24
## n_hidden alpha include_original seed
## 8 250 0.3 TRUE 1
## Best RMSE: 3.0234
## Best lambda: 0.025564
## Non-zero coeffs: 164
##
## ========================================
## Run 9 / 24
## n_hidden alpha include_original seed
## 9 175 0.5 TRUE 1
## Best RMSE: 2.9385
## Best lambda: 0.016834
## Non-zero coeffs: 136
##
## ========================================
## Run 10 / 24
## n_hidden alpha include_original seed
## 10 200 0.5 TRUE 1
## Best RMSE: 2.8893
## Best lambda: 0.012734
## Non-zero coeffs: 158
##
## ========================================
## Run 11 / 24
## n_hidden alpha include_original seed
## 11 225 0.5 TRUE 1
## Best RMSE: 2.9942
## Best lambda: 0.024423
## Non-zero coeffs: 137
##
## ========================================
## Run 12 / 24
## n_hidden alpha include_original seed
## 12 250 0.5 TRUE 1
## Best RMSE: 3.0323
## Best lambda: 0.018475
## Non-zero coeffs: 151
##
## ========================================
## Run 13 / 24
## n_hidden alpha include_original seed
## 13 175 0.1 FALSE 1
## Best RMSE: 3.0930
## Best lambda: 0.028256
## Non-zero coeffs: 162
##
## ========================================
## Run 14 / 24
## n_hidden alpha include_original seed
## 14 200 0.1 FALSE 1
## Best RMSE: 3.1215
## Best lambda: 0.031011
## Non-zero coeffs: 178
##
## ========================================
## Run 15 / 24
## n_hidden alpha include_original seed
## 15 225 0.1 FALSE 1
## Best RMSE: 3.2248
## Best lambda: 0.057806
## Non-zero coeffs: 179
##
## ========================================
## Run 16 / 24
## n_hidden alpha include_original seed
## 16 250 0.1 FALSE 1
## Best RMSE: 3.2307
## Best lambda: 0.047992
## Non-zero coeffs: 205
##
## ========================================
## Run 17 / 24
## n_hidden alpha include_original seed
## 17 175 0.3 FALSE 1
## Best RMSE: 3.0828
## Best lambda: 0.019825
## Non-zero coeffs: 145
##
## ========================================
## Run 18 / 24
## n_hidden alpha include_original seed
## 18 200 0.3 FALSE 1
## Best RMSE: 3.1427
## Best lambda: 0.019825
## Non-zero coeffs: 154
##
## ========================================
## Run 19 / 24
## n_hidden alpha include_original seed
## 19 225 0.3 FALSE 1
## Best RMSE: 3.2248
## Best lambda: 0.036956
## Non-zero coeffs: 131
##
## ========================================
## Run 20 / 24
## n_hidden alpha include_original seed
## 20 250 0.3 FALSE 1
## Best RMSE: 3.2282
## Best lambda: 0.027956
## Non-zero coeffs: 154
##
## ========================================
## Run 21 / 24
## n_hidden alpha include_original seed
## 21 175 0.5 FALSE 1
## Best RMSE: 3.1097
## Best lambda: 0.013055
## Non-zero coeffs: 136
##
## ========================================
## Run 22 / 24
## n_hidden alpha include_original seed
## 22 200 0.5 FALSE 1
## Best RMSE: 3.1669
## Best lambda: 0.011895
## Non-zero coeffs: 153
##
## ========================================
## Run 23 / 24
## n_hidden alpha include_original seed
## 23 225 0.5 FALSE 1
## Best RMSE: 3.2196
## Best lambda: 0.026708
## Non-zero coeffs: 131
##
## ========================================
## Run 24 / 24
## n_hidden alpha include_original seed
## 24 250 0.5 FALSE 1
## Best RMSE: 3.2321
## Best lambda: 0.020204
## Non-zero coeffs: 144
# -------------------------
# Aggregate
# -------------------------
# Stack the per-run summary rows and rank configurations by RMSE, best first.
results_df <- do.call(rbind, results)
results_df <- results_df[order(results_df$rmse), ]
# Reset row names: after rbind() + order() they are either permuted run
# indices or labels inherited from the per-run values, neither meaningful.
row.names(results_df) <- NULL
results_df
## n_hidden alpha include_original seed rmse lambda
## s= 0.027561759 200 0.1 TRUE 1 2.881935 0.02756176
## s= 0.017620327 200 0.3 TRUE 1 2.884739 0.01762033
## s= 0.012734248 200 0.5 TRUE 1 2.889339 0.01273425
## s= 0.036435024 175 0.1 TRUE 1 2.920012 0.03643502
## s= 0.016833926 175 0.5 TRUE 1 2.938472 0.01683393
## s= 0.023293035 175 0.3 TRUE 1 2.941267 0.02329304
## s= 0.024423144 225 0.5 TRUE 1 2.994151 0.02442314
## s= 0.037089099 225 0.3 TRUE 1 3.007938 0.03708910
## s= 0.025564078 250 0.3 TRUE 1 3.023438 0.02556408
## s= 0.018475213 250 0.5 TRUE 1 3.032290 0.01847521
## s= 0.063671239 225 0.1 TRUE 1 3.034159 0.06367124
## s= 0.052860981 250 0.1 TRUE 1 3.047206 0.05286098
## s= 0.019825282 175 0.3 FALSE 1 3.082774 0.01982528
## s= 0.028255845 175 0.1 FALSE 1 3.093027 0.02825585
## s= 0.013054934 175 0.5 FALSE 1 3.109746 0.01305493
## s= 0.031010755 200 0.1 FALSE 1 3.121483 0.03101076
## s= 0.0198252821 200 0.3 FALSE 1 3.142658 0.01982528
## s= 0.011895169 200 0.5 FALSE 1 3.166913 0.01189517
## s= 0.026708050 225 0.5 FALSE 1 3.219590 0.02670805
## s= 0.057806387 225 0.1 FALSE 1 3.224829 0.05780639
## s= 0.036955821 225 0.3 FALSE 1 3.224837 0.03695582
## s= 0.027955723 250 0.3 FALSE 1 3.228242 0.02795572
## s= 0.047991878 250 0.1 FALSE 1 3.230742 0.04799188
## s= 0.020203661 250 0.5 FALSE 1 3.232122 0.02020366
## nonzero
## s= 0.027561759 190
## s= 0.017620327 167
## s= 0.012734248 158
## s= 0.036435024 165
## s= 0.016833926 136
## s= 0.023293035 144
## s= 0.024423144 137
## s= 0.037089099 149
## s= 0.025564078 164
## s= 0.018475213 151
## s= 0.063671239 183
## s= 0.052860981 195
## s= 0.019825282 145
## s= 0.028255845 162
## s= 0.013054934 136
## s= 0.031010755 178
## s= 0.0198252821 154
## s= 0.011895169 153
## s= 0.026708050 131
## s= 0.057806387 179
## s= 0.036955821 131
## s= 0.027955723 154
## s= 0.047991878 205
## s= 0.020203661 144