Parameter descriptions can be found at https://techtonique.github.io/nnetsauce/

RandomBagClassifier(
  obj,
  n_estimators = 50L,
  n_hidden_features = 5L,
  activation_name = "relu",
  a = 0.01,
  nodes_sim = "sobol",
  bias = TRUE,
  dropout = 0,
  direct_link = FALSE,
  n_clusters = 2L,
  cluster_encode = TRUE,
  type_clust = "kmeans",
  col_sample = 1,
  row_sample = 1,
  n_jobs = NULL,
  seed = 123L,
  verbose = 1L,
  backend = c("cpu", "gpu", "tpu")
)

Examples


library(datasets)

# Predictors: the four numeric iris measurements, as a matrix.
# Response: species recoded as 0-based integers (0, 1, 2).
X <- as.matrix(iris[, -5])
y <- as.integer(iris$Species) - 1L

# Draw 80% of the row indices (without replacement) for training and print
# them. No seed is set here, so these indices -- and the results further
# down -- may differ from run to run.
index_train <- base::sample.int(n = nrow(X), size = floor(0.8 * nrow(X)))
index_train
#>   [1] 127  41   8  52  67   2  65 135  29  20 124  80   7  45  76  25  91  94
#>  [19] 150  12 148  71 121  15  83  49 107  39 115  87  35   6 129   4  74  81
#>  [37]  48 130  53 146 108 112 101  14  28 119   9  33  46  64  19  98   1  92
#>  [55] 147 133  68  90  75   5 100 120 103  69  61  23 149  63 125 128 110 136
#>  [73] 145  18 117  72 144  11  57  56 132  96  86  36   3 138  79 118  40 109
#>  [91]  37  24  38 139  66  73  10 142  85 134  59  32 143  55  54  70  31 102
#> [109]  50  89  43 113  82 116  17 141  88  62  13  34

# Rows drawn above form the training set; all remaining rows form the test set.
X_train <- X[index_train, ]
X_test <- X[-index_train, ]
y_train <- y[index_train]
y_test <- y[-index_train]

# NOTE(review): `sklearn` is not created in this example; presumably it is the
# scikit-learn module exposed (via reticulate) when the package is loaded --
# confirm.
base_learner <- sklearn$tree$DecisionTreeClassifier()
bagged_model <- RandomBagClassifier(
  base_learner,
  n_estimators = 50L,
  n_hidden_features = 5L
)

# Fitting returns the fitted model, which is auto-printed at top level.
bagged_model$fit(X_train, y_train)
#> RandomBagClassifier(col_sample=1.0, dropout=0.0, n_estimators=50,
#>                     n_hidden_features=5, obj=DecisionTreeClassifier(),
#>                     row_sample=1.0)

# Score on the held-out rows, then show the per-class probabilities
# (one row per test observation, one column per class).
test_score <- bagged_model$score(X_test, y_test)
print(test_score)
#> [1] 0.7333333
test_proba <- bagged_model$predict_proba(X_test)
print(test_proba)
#>            [,1]      [,2] [,3]
#>  [1,] 1.0000000 0.0000000 0.00
#>  [2,] 1.0000000 0.0000000 0.00
#>  [3,] 1.0000000 0.0000000 0.00
#>  [4,] 0.8861416 0.1138584 0.00
#>  [5,] 1.0000000 0.0000000 0.00
#>  [6,] 0.8861416 0.1138584 0.00
#>  [7,] 0.8861416 0.1138584 0.00
#>  [8,] 1.0000000 0.0000000 0.00
#>  [9,] 1.0000000 0.0000000 0.00
#> [10,] 0.0000000 0.4000000 0.60
#> [11,] 0.8861416 0.1138584 0.00
#> [12,] 0.0000000 0.4400000 0.56
#> [13,] 0.0000000 0.1000000 0.90
#> [14,] 0.0000000 0.0000000 1.00
#> [15,] 0.0000000 1.0000000 0.00
#> [16,] 0.0000000 1.0000000 0.00
#> [17,] 0.0000000 1.0000000 0.00
#> [18,] 0.0000000 1.0000000 0.00
#> [19,] 0.8861416 0.1138584 0.00
#> [20,] 0.0000000 0.9000000 0.10
#> [21,] 0.0000000 0.0000000 1.00
#> [22,] 0.0000000 0.0000000 1.00
#> [23,] 0.0000000 0.0000000 1.00
#> [24,] 0.0000000 0.8000000 0.20
#> [25,] 0.0000000 0.1200000 0.88
#> [26,] 0.0000000 0.0000000 1.00
#> [27,] 0.0000000 0.0000000 1.00
#> [28,] 0.0000000 0.0000000 1.00
#> [29,] 0.0000000 0.0000000 1.00
#> [30,] 0.0000000 0.0000000 1.00