Introduction

The tisthemachinelearner package provides a simple R interface to scikit-learn models through Python’s tisthemachinelearner package. This vignette demonstrates how to use the package with the Boston housing dataset (`Boston`) from the MASS package, which ships with R.

Setup

First, let’s load the required packages:

library(tisthemachinelearner)
#> Loading required package: reticulate
#> Loading required package: Matrix
library(reticulate)

Data Preparation

We’ll use the classic Boston housing dataset (`MASS::Boston`) to predict the median value of owner-occupied homes (`medv`, in $1000s) from the other neighbourhood characteristics:

# Load data: the Boston housing dataset shipped with the MASS package
boston <- MASS::Boston

# Split features and target, selecting the target by name rather than by
# position so the code does not silently break if the column order changes
target <- "medv"  # median value of owner-occupied homes ($1000s)
X <- as.matrix(boston[, setdiff(names(boston), target)])  # all but medv
y <- boston[[target]]                                     # medv column

# Create an 80/20 train/test split (seeded for reproducibility)
set.seed(42)
train_idx <- sample(nrow(X), size = floor(0.8 * nrow(X)))
X_train <- X[train_idx, , drop = FALSE]
X_test <- X[-train_idx, , drop = FALSE]
y_train <- y[train_idx]
y_test <- y[-train_idx]

Boosting with ExtraTreeRegressor Base Learners

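Now let’s fit a boosted model that uses scikit-learn’s `ExtraTreeRegressor` as the base learner (up to 100 iterations, learning rate 0.1). We time both fitting and prediction, then evaluate the model on the test set with the root-mean-squared error (RMSE):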

# Fit booster model ----
# Mark the start of the fit timer; `time` is read again once fitting ends.
time <- proc.time()[3]
reg_booster <- tisthemachinelearner::booster(
  X_train, y_train, "ExtraTreeRegressor",
  n_estimators = 100L,
  learning_rate = 0.1,
  show_progress = TRUE,
  verbose = TRUE
)
#> Iteration 1: loss = 21.6277
#> Iteration 2: loss = 19.465
#> Iteration 3: loss = 17.5185
#> Iteration 4: loss = 15.7666
#> Iteration 5: loss = 14.19
#> Iteration 6: loss = 12.771
#> Iteration 7: loss = 11.4939
#> Iteration 8: loss = 10.3445
#> Iteration 9: loss = 9.31003
#> Iteration 10: loss = 8.37903
#> Iteration 11: loss = 7.54112
#> Iteration 12: loss = 6.78701
#> Iteration 13: loss = 6.10831
#> Iteration 14: loss = 5.49748
#> Iteration 15: loss = 4.94773
#> Iteration 16: loss = 4.45296
#> Iteration 17: loss = 4.00766
#> Iteration 18: loss = 3.6069
#> Iteration 19: loss = 3.24621
#> Iteration 20: loss = 2.92159
#> Iteration 21: loss = 2.62943
#> Iteration 22: loss = 2.36648
#> Iteration 23: loss = 2.12984
#> Iteration 24: loss = 1.91685
#> Iteration 25: loss = 1.72517
#> Iteration 26: loss = 1.55265
#> Iteration 27: loss = 1.39739
#> Iteration 28: loss = 1.25765
#> Iteration 29: loss = 1.13188
#> Iteration 30: loss = 1.01869
#> Iteration 31: loss = 0.916825
#> Iteration 32: loss = 0.825142
#> Iteration 33: loss = 0.742628
#> Iteration 34: loss = 0.668365
#> Iteration 35: loss = 0.601529
#> Iteration 36: loss = 0.541376
#> Iteration 37: loss = 0.487238
#> Iteration 38: loss = 0.438514
#> Iteration 39: loss = 0.394663
#> Iteration 40: loss = 0.355197
#> Iteration 41: loss = 0.319677
#> Iteration 42: loss = 0.287709
#> Iteration 43: loss = 0.258938
#> Iteration 44: loss = 0.233045
#> Iteration 45: loss = 0.20974
#> Iteration 46: loss = 0.188766
#> Iteration 47: loss = 0.169889
#> Iteration 48: loss = 0.152901
#> Iteration 49: loss = 0.13761
#> Iteration 50: loss = 0.123849
#> Iteration 51: loss = 0.111464
#> Iteration 52: loss = 0.100318
#> Iteration 53: loss = 0.0902862
#> Iteration 54: loss = 0.0812576
#> Iteration 55: loss = 0.0731318
#> Iteration 56: loss = 0.0658187
#> Iteration 57: loss = 0.0592368
#> Iteration 58: loss = 0.0533131
#> Iteration 59: loss = 0.0479818
#> Iteration 60: loss = 0.0431836
#> Iteration 61: loss = 0.0388653
#> Iteration 62: loss = 0.0349787
#> Iteration 63: loss = 0.0314809
#> Iteration 64: loss = 0.0283328
#> Iteration 65: loss = 0.0254995
#> Iteration 66: loss = 0.0229495
#> Iteration 67: loss = 0.0206546
#> Iteration 68: loss = 0.0185891
#> Iteration 69: loss = 0.0167302
#> Iteration 70: loss = 0.0150572
#> Iteration 71: loss = 0.0135515
#> Iteration 72: loss = 0.0121963
#> Iteration 73: loss = 0.0109767
#> Iteration 74: loss = 0.00987903
#> Iteration 75: loss = 0.00889113
#> Iteration 76: loss = 0.00800201
#> Iteration 77: loss = 0.00720181
#> Iteration 78: loss = 0.00648163
#> Iteration 79: loss = 0.00583347
#> Iteration 80: loss = 0.00525012
#> Iteration 81: loss = 0.00472511
#> Iteration 82: loss = 0.0042526
#> Iteration 83: loss = 0.00382734
#> Iteration 84: loss = 0.0034446
#> Iteration 85: loss = 0.00310014
#> Iteration 86: loss = 0.00279013
#> Iteration 87: loss = 0.00251112
#> Iteration 88: loss = 0.00226
#> Iteration 89: loss = 0.002034
#> Iteration 90: loss = 0.0018306
#> Iteration 91: loss = 0.00164754
#> Iteration 92: loss = 0.00148279
#> Iteration 93: loss = 0.00133451
#> Iteration 94: loss = 0.00120106
#> Iteration 95: loss = 0.00108095
#> Iteration 96: loss = 0.000972858
#> Iteration 97: loss = 0.000875572
# Stop the fit timer: elapsed wall-clock seconds since the start mark
time <- proc.time()["elapsed"] - time
cat("Time taken:", time, "seconds\n")
#> Time taken: 32.64 seconds

# Make predictions on the held-out set, timing the predict() call
time <- proc.time()["elapsed"]
predictions <- predict(reg_booster, X_test)
time <- proc.time()["elapsed"] - time
cat("Time taken:", time, "seconds\n")
#> Time taken: 0.073 seconds

# RMSE on the held-out test set: square root of the mean squared error
squared_errors <- (predictions - y_test)^2
rmse <- sqrt(mean(squared_errors))
cat("RMSE:", rmse, "\n")
#> RMSE: 2.841826

Session Info

# Record R version, platform and loaded packages for reproducibility
utils::sessionInfo()
#> R version 4.3.3 (2024-02-29)
#> Platform: x86_64-apple-darwin20 (64-bit)
#> Running under: macOS Sonoma 14.2
#> 
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/4.3-x86_64/Resources/lib/libRblas.0.dylib 
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.3-x86_64/Resources/lib/libRlapack.dylib;  LAPACK version 3.11.0
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> time zone: Europe/Paris
#> tzcode source: internal
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] tisthemachinelearner_0.3.0 Matrix_1.6-5              
#> [3] reticulate_1.42.0         
#> 
#> loaded via a namespace (and not attached):
#>  [1] cli_3.6.4         knitr_1.49        rlang_1.1.6       xfun_0.51        
#>  [5] png_0.1-8         textshaping_1.0.0 jsonlite_2.0.0    htmltools_0.5.8.1
#>  [9] ragg_1.3.3        sass_0.4.9        rmarkdown_2.29    grid_4.3.3       
#> [13] evaluate_1.0.3    jquerylib_0.1.4   MASS_7.3-60.0.1   fastmap_1.2.0    
#> [17] yaml_2.3.10       lifecycle_1.0.4   compiler_4.3.3    fs_1.6.5         
#> [21] htmlwidgets_1.6.4 Rcpp_1.0.14       systemfonts_1.1.0 lattice_0.22-5   
#> [25] digest_0.6.37     R6_2.6.1          bslib_0.9.0       tools_4.3.3      
#> [29] pkgdown_2.1.1     cachem_1.1.0      desc_1.4.3