Provides a consistent interface for various machine learning models in R, with automatic detection of formula vs matrix interfaces, built-in cross-validation, model interpretability, and visualization.
An R6 class that provides a unified interface for regression and classification models with automatic interface detection, cross-validation, and interpretability features. The task type (regression vs classification) is automatically detected from the response variable type.
model_fn: The modeling function (e.g., glmnet::glmnet, randomForest::randomForest)
fitted: The fitted model object
task: Type of task, either "regression" or "classification" (automatically detected)
X_train: Training features matrix
y_train: Training target vector
fit(): Fit the model to training data
Automatically detects the task type (regression vs classification) from the class of the response variable y: numeric y -> regression, factor y -> classification. The detection rule is sketched after this method list.
predict(): Generate predictions from the fitted model
print(): Print model information
summary(): Compute numerical derivatives and statistical significance
Uses finite differences to compute approximate partial derivatives of the predictions with respect to each feature, providing model-agnostic interpretability. See the finite-difference sketch after this method list.
h: Step size for finite differences (default: 0.01)
alpha: Significance level for p-values (default: 0.05)
plot(): Create a partial dependence plot for a feature
Visualizes the relationship between a feature and the predicted outcome while holding all other features at their mean values. See the partial dependence sketch after this method list.
clone_model(): Create a deep copy of the model
Useful for cross-validation and parallel processing, where multiple independent model instances are needed. A cross-validation sketch follows this method list.
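The task detection rule described for fit() can be written as a small standalone helper. This is only a sketch of the documented behaviour; the function name detect_task() is illustrative and not part of the package API.

detect_task <- function(y) {
  # Documented rule: numeric y -> regression, factor y -> classification
  if (is.factor(y)) {
    "classification"
  } else if (is.numeric(y)) {
    "regression"
  } else {
    stop("Unsupported response type: ", class(y)[1])
  }
}

detect_task(rnorm(10))                 # "regression"
detect_task(factor(c("a", "b", "a")))  # "classification"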
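The finite-difference idea behind summary() can be sketched model-agnostically: perturb one feature by a small step h, difference the predictions, and summarize the per-sample derivatives. The helper name numeric_derivative() and the use of a central difference are assumptions for illustration; the package's internal implementation may differ.

numeric_derivative <- function(predict_fn, X, feature, h = 0.01) {
  X_plus <- X_minus <- X
  X_plus[, feature]  <- X_plus[, feature]  + h
  X_minus[, feature] <- X_minus[, feature] - h
  # Central difference: (f(x + h) - f(x - h)) / (2h), one value per sample
  (predict_fn(X_plus) - predict_fn(X_minus)) / (2 * h)
}

# Toy check with a linear model: the true derivative for feature 1 is 2
set.seed(42)
X <- matrix(rnorm(200), ncol = 4)
y <- 2 * X[, 1] - X[, 2] + rnorm(50, sd = 0.1)
fit <- lm(y ~ X)
pred_fn <- function(newX) as.numeric(cbind(1, newX) %*% coef(fit))
d1 <- numeric_derivative(pred_fn, X, feature = 1)
c(mean_derivative = mean(d1), std_error = sd(d1) / sqrt(length(d1)))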
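Likewise, the partial dependence logic described for plot() amounts to sweeping one feature over a grid while fixing all other features at their means, then predicting. The helper below reuses pred_fn and X from the previous sketch; its name and arguments are illustrative, not the package's API.

partial_dependence <- function(predict_fn, X, feature, grid_size = 50) {
  grid <- seq(min(X[, feature]), max(X[, feature]), length.out = grid_size)
  # Reference rows: every feature held at its column mean ...
  X_ref <- matrix(colMeans(X), nrow = grid_size, ncol = ncol(X), byrow = TRUE)
  X_ref[, feature] <- grid   # ... except the feature being varied
  data.frame(x = grid, yhat = predict_fn(X_ref))
}

pd <- partial_dependence(pred_fn, X, feature = 1)
plot(pd$x, pd$yhat, type = "l",
     xlab = "Feature 1", ylab = "Predicted outcome")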
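Finally, a sketch of how clone_model() could support k-fold cross-validation, assuming fit(), predict(), and clone_model() behave as documented above; the fold loop itself is not part of the package.

set.seed(1)
X_cv <- matrix(rnorm(100), ncol = 4)
y_cv <- 2 * X_cv[, 1] - 1.5 * X_cv[, 2] + rnorm(25)
folds <- sample(rep(1:5, length.out = nrow(X_cv)))
base_mod <- Model$new(glmnet::glmnet)
cv_mse <- sapply(1:5, function(k) {
  m <- base_mod$clone_model()          # independent copy per fold
  m$fit(X_cv[folds != k, ], y_cv[folds != k], alpha = 0, lambda = 0.1)
  mean((m$predict(X_cv[folds == k, ]) - y_cv[folds == k])^2)
})
mean(cv_mse)                           # average held-out MSE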
# \donttest{
# Regression example with glmnet
library(glmnet)
#> Loading required package: Matrix
#> Loaded glmnet 4.1-10
X <- matrix(rnorm(100), ncol = 4)
y <- 2*X[,1] - 1.5*X[,2] + rnorm(25) # numeric -> regression
mod <- Model$new(glmnet::glmnet)
mod$fit(X, y, alpha = 0, lambda = 0.1)
mod$summary()
#>
#> Model Summary - Numerical Derivatives
#> ======================================
#> Task: regression
#> Samples: 25 | Features: 4
#> Step size (h): 0.01
#>
#> Feature Mean_Derivative Std_Error t_value p_value Significance
#> X1 1.91049526 3.046816e-15 6.270466e+14 0.00000e+00 ***
#> X2 -0.81015913 2.124291e-15 -3.813786e+14 0.00000e+00 ***
#> X3 0.05898194 2.266233e-16 2.602642e+14 0.00000e+00 ***
#> X4 -0.20547449 1.952581e-15 -1.052323e+14 1.72923e-321 ***
#>
#> Significance codes: 0 '***' 0.01 '**' 0.05 '*' 0.1 ' ' 1
predictions <- mod$predict(X)
# Classification example
data(iris)
iris_binary <- iris[iris$Species %in% c("setosa", "versicolor"), ]
X_class <- as.matrix(iris_binary[, 1:4])
y_class <- droplevels(iris_binary$Species) # factor -> classification
mod2 <- Model$new(e1071::svm)
mod2$fit(X_class, y_class, kernel = "radial")
predictions <- mod2$predict(X_class)
mod2$predict_proba(X_class)
#> setosa versicolor
#> 1 1 0
#> 2 1 0
#> 3 1 0
#> 4 1 0
#> 5 1 0
#> 6 1 0
#> 7 1 0
#> 8 1 0
#> 9 1 0
#> 10 1 0
#> 11 1 0
#> 12 1 0
#> 13 1 0
#> 14 1 0
#> 15 1 0
#> 16 1 0
#> 17 1 0
#> 18 1 0
#> 19 1 0
#> 20 1 0
#> 21 1 0
#> 22 1 0
#> 23 1 0
#> 24 1 0
#> 25 1 0
#> 26 1 0
#> 27 1 0
#> 28 1 0
#> 29 1 0
#> 30 1 0
#> 31 1 0
#> 32 1 0
#> 33 1 0
#> 34 1 0
#> 35 1 0
#> 36 1 0
#> 37 1 0
#> 38 1 0
#> 39 1 0
#> 40 1 0
#> 41 1 0
#> 42 1 0
#> 43 1 0
#> 44 1 0
#> 45 1 0
#> 46 1 0
#> 47 1 0
#> 48 1 0
#> 49 1 0
#> 50 1 0
#> 51 0 1
#> 52 0 1
#> 53 0 1
#> 54 0 1
#> 55 0 1
#> 56 0 1
#> 57 0 1
#> 58 0 1
#> 59 0 1
#> 60 0 1
#> 61 0 1
#> 62 0 1
#> 63 0 1
#> 64 0 1
#> 65 0 1
#> 66 0 1
#> 67 0 1
#> 68 0 1
#> 69 0 1
#> 70 0 1
#> 71 0 1
#> 72 0 1
#> 73 0 1
#> 74 0 1
#> 75 0 1
#> 76 0 1
#> 77 0 1
#> 78 0 1
#> 79 0 1
#> 80 0 1
#> 81 0 1
#> 82 0 1
#> 83 0 1
#> 84 0 1
#> 85 0 1
#> 86 0 1
#> 87 0 1
#> 88 0 1
#> 89 0 1
#> 90 0 1
#> 91 0 1
#> 92 0 1
#> 93 0 1
#> 94 0 1
#> 95 0 1
#> 96 0 1
#> 97 0 1
#> 98 0 1
#> 99 0 1
#> 100 0 1
#> attr(,"assign")
#> [1] 1 1
#> attr(,"contrasts")
#> attr(,"contrasts")$pred
#> [1] "contr.treatment"
#>
#> attr(,"extraction_method")
#> [1] "fallback::1"
#> attr(,"model_class")
#> [1] "svm.formula"
# }