Provides a consistent interface for various machine learning models in R, with automatic detection of formula vs matrix interfaces, built-in cross-validation, model interpretability, and visualization.

An R6 class that provides a unified interface for regression and classification models with automatic interface detection, cross-validation, and interpretability features. The task type (regression vs classification) is automatically detected from the response variable type.
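The class's public surface can be pictured with a minimal R6 sketch. This is illustrative only: the field and method names follow this page, but the bodies are assumptions, and the real class additionally handles formula-vs-matrix detection, cross-validation, and the interpretability methods documented below.

```r
library(R6)

# Illustrative skeleton only -- not the package's implementation.
Model <- R6Class("Model",
  public = list(
    model_fn = NULL,   # the modeling function, e.g. glmnet::glmnet
    fitted   = NULL,   # the fitted model object
    task     = NULL,   # "regression" or "classification"
    X_train  = NULL,   # training features
    y_train  = NULL,   # training target

    initialize = function(model_fn) {
      stopifnot(is.function(model_fn))
      self$model_fn <- model_fn
    },

    fit = function(X, y, ...) {
      # task type is inferred from the class of y
      self$task    <- if (is.factor(y)) "classification" else "regression"
      self$X_train <- X
      self$y_train <- y
      self$fitted  <- self$model_fn(X, y, ...)
      invisible(self)   # invisible self enables method chaining
    }
  )
)
```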

Author

Your Name

Public fields

model_fn

The modeling function (e.g., glmnet::glmnet, randomForest::randomForest)

fitted

The fitted model object

task

Type of task: "regression" or "classification" (automatically detected)

X_train

Training features matrix

y_train

Training target vector

Methods


Method new()

Initialize a new Model

Usage

Model$new(model_fn)

Arguments

model_fn

A modeling function (e.g., glmnet, randomForest, svm)

Returns

A new Model object


Method fit()

Fit the model to training data

Automatically detects the task type based on the class of the response variable y: numeric y is treated as regression, factor y as classification.

Usage

Model$fit(X, y, ...)

Arguments

X

Feature matrix or data.frame

y

Target vector (numeric for regression, factor for classification)

...

Additional arguments passed to the model function

Returns

self (invisible) for method chaining
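As noted in the overview, the class also detects whether the modeling function expects a formula or an (X, y) matrix interface. A standalone sketch of one way such detection could work; the helper name `fit_with_detected_interface` and the try-then-fall-back strategy are assumptions, not the documented implementation:

```r
# Hedged sketch: try the matrix interface first, and fall back to a
# formula interface built from a data.frame if that call errors.
fit_with_detected_interface <- function(model_fn, X, y, ...) {
  tryCatch(
    model_fn(X, y, ...),                  # matrix interface, e.g. glmnet
    error = function(e) {
      df <- data.frame(.y = y, as.data.frame(X))
      model_fn(.y ~ ., data = df, ...)    # formula interface, e.g. lm
    }
  )
}
```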


Method predict()

Generate predictions from fitted model

Usage

Model$predict(X, type = NULL, ...)

Arguments

X

Feature matrix for prediction

type

Type of prediction ("response", "class", "probabilities")

...

Additional arguments passed to predict function

Returns

Vector of predictions
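The `type` argument may need translating for different backends; for example, `randomForest`'s predict method spells class probabilities `type = "prob"`. A hypothetical helper showing the idea, where the exact mapping and the `NULL` default are assumptions:

```r
# Hedged sketch: map the unified `type` values to a backend-specific
# predict() type, defaulting by task when `type` is NULL.
translate_type <- function(type, task) {
  if (is.null(type)) {
    type <- if (task == "classification") "class" else "response"
  }
  switch(type,
    response      = "response",
    class         = "class",
    probabilities = "prob",   # e.g. randomForest uses type = "prob"
    type                      # pass unknown values through unchanged
  )
}
```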


Method print()

Print model information

Usage

Model$print()

Returns

self (invisible) for method chaining


Method summary()

Compute numerical derivatives and statistical significance

Uses finite differences to compute approximate partial derivatives for each feature, providing model-agnostic interpretability.

Usage

Model$summary(h = 0.01, alpha = 0.05)

Arguments

h

Step size for finite differences (default: 0.01)

alpha

Significance level for p-values (default: 0.05)

Details

The method computes numerical derivatives using central differences.

Statistical significance is assessed using t-tests on the derivative estimates across samples.

Returns

A data.frame with derivative statistics (invisible)
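The central-difference scheme described above can be sketched model-agnostically. Here `predict_fn` stands in for the fitted model's prediction function; the helper name and return shape are illustrative, not the documented implementation. Significance could then be assessed per feature with a t-test on the column of per-sample derivatives.

```r
# Hedged sketch of model-agnostic central differences:
# (f(x + h) - f(x - h)) / (2h), computed per sample for each feature.
numeric_derivatives <- function(predict_fn, X, h = 0.01) {
  sapply(seq_len(ncol(X)), function(j) {
    X_hi <- X; X_hi[, j] <- X_hi[, j] + h
    X_lo <- X; X_lo[, j] <- X_lo[, j] - h
    (predict_fn(X_hi) - predict_fn(X_lo)) / (2 * h)
  })
  # Result: an n-samples x n-features matrix of derivative estimates;
  # e.g. t.test(result[, j]) assesses significance for feature j.
}
```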


Method plot()

Create partial dependence plot for a feature

Visualizes the relationship between a feature and the predicted outcome while holding other features at their mean values.

Usage

Model$plot(feature = 1, n_points = 100)

Arguments

feature

Index or name of feature to plot

n_points

Number of points for the grid (default: 100)

Returns

self (invisible) for method chaining
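The grid construction described above (sweep one feature over its observed range while holding the others at their training means) can be sketched as follows; the helper name and data.frame return are assumptions:

```r
# Hedged sketch of a partial-dependence grid. `predict_fn` stands in
# for the fitted model's prediction function.
partial_dependence <- function(predict_fn, X, feature, n_points = 100) {
  grid <- seq(min(X[, feature]), max(X[, feature]), length.out = n_points)
  # every row holds the column means, then the target feature is varied
  X_grid <- matrix(colMeans(X), nrow = n_points, ncol = ncol(X), byrow = TRUE)
  X_grid[, feature] <- grid
  data.frame(x = grid, yhat = predict_fn(X_grid))
}
```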


Method clone_model()

Create a deep copy of the model

Useful for cross-validation and parallel processing where multiple independent model instances are needed.

Usage

Model$clone_model()

Returns

A new Model object with the same configuration


Method clone()

The objects of this class are cloneable with this method.

Usage

Model$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

if (FALSE) { # \dontrun{
# Regression example with glmnet
library(glmnet)
X <- matrix(rnorm(100), ncol = 4)     # 25 rows x 4 features
y <- 2*X[,1] - 1.5*X[,2] + rnorm(25)  # numeric → regression

mod <- Model$new(glmnet::glmnet)
mod$fit(X, y, alpha = 0, lambda = 0.1)
mod$summary()
predictions <- mod$predict(X)

# Classification example  
data(iris)
iris_binary <- iris[iris$Species %in% c("setosa", "versicolor"), ]
X_class <- as.matrix(iris_binary[, 1:4])
y_class <- iris_binary$Species  # factor → classification

mod2 <- Model$new(e1071::svm)
mod2$fit(X_class, y_class, kernel = "radial")
mod2$summary()

# Cross-validation
cv_scores <- cross_val_score(mod, X, y, cv = 5)
} # }