RandomBagClassifier.Rd
Parameter descriptions can be found at https://techtonique.github.io/nnetsauce/.
# Usage: RandomBagClassifier() with its default arguments.
# `obj` (the base learner; a scikit-learn DecisionTreeClassifier in the
# example below) is the only argument without a default.
RandomBagClassifier(
obj,
n_estimators = 50L,
n_hidden_features = 5L,
activation_name = "relu",
a = 0.01,
nodes_sim = "sobol",
bias = TRUE,
dropout = 0,
direct_link = FALSE,
n_clusters = 2L,
cluster_encode = TRUE,
type_clust = "kmeans",
col_sample = 1,
row_sample = 1,
n_jobs = NULL,
seed = 123L,
verbose = 1L,
# NOTE(review): default lists the accepted backends; presumably one is
# selected match.arg()-style (first = "cpu") — confirm in the implementation.
backend = c("cpu", "gpu", "tpu")
)
# Features: the four numeric iris measurements (every column except the 5th).
# Labels: the species factor recoded as 0-based integers (0, 1, 2).
library(datasets)
X <- as.matrix(iris[, -5L])
y <- as.integer(iris$Species) - 1L

# Draw 80% of the row indices, without replacement, for the training set.
# NOTE(review): no seed is set here before sampling, so the indices printed
# below (and every downstream score) vary between runs; consider set.seed().
index_train <- base::sample.int(
  n = nrow(X),
  size = floor(nrow(X) * 0.8),
  replace = FALSE
)
print(index_train)
#> [1] 127 41 8 52 67 2 65 135 29 20 124 80 7 45 76 25 91 94
#> [19] 150 12 148 71 121 15 83 49 107 39 115 87 35 6 129 4 74 81
#> [37] 48 130 53 146 108 112 101 14 28 119 9 33 46 64 19 98 1 92
#> [55] 147 133 68 90 75 5 100 120 103 69 61 23 149 63 125 128 110 136
#> [73] 145 18 117 72 144 11 57 56 132 96 86 36 3 138 79 118 40 109
#> [91] 37 24 38 139 66 73 10 142 85 134 59 32 143 55 54 70 31 102
#> [109] 50 89 43 113 82 116 17 141 88 62 13 34
# Rows whose index was sampled form the training set; all remaining rows
# (negative indexing, original order kept) form the test set.
X_train <- X[index_train, ]
X_test <- X[-index_train, ]

y_train <- y[index_train]
y_test <- y[-index_train]
# Base learner: a scikit-learn decision tree reached through `sklearn`
# (presumably a reticulate module handle provided by the package — confirm).
obj <- sklearn$tree$DecisionTreeClassifier()

# Wrap it in the bagging ensemble. Both values below restate the usage
# defaults (n_estimators = 50L, n_hidden_features = 5L).
obj2 <- RandomBagClassifier(
  obj,
  n_estimators = 50L,
  n_hidden_features = 5L
)

# Left unassigned on purpose: the fitted model's repr is printed below.
obj2$fit(X_train, y_train)
#> RandomBagClassifier(col_sample=1.0, dropout=0.0, n_estimators=50,
#> n_hidden_features=5, obj=DecisionTreeClassifier(),
#> row_sample=1.0)
# Score on the 30 held-out rows (presumably mean accuracy, as for
# scikit-learn classifiers — confirm).
print(obj2$score(X_test, y_test))
#> [1] 0.7333333
# Predicted probabilities: one row per test observation; three columns,
# presumably one per class label 0, 1, 2 (each row sums to 1 below).
print(obj2$predict_proba(X_test))
#> [,1] [,2] [,3]
#> [1,] 1.0000000 0.0000000 0.00
#> [2,] 1.0000000 0.0000000 0.00
#> [3,] 1.0000000 0.0000000 0.00
#> [4,] 0.8861416 0.1138584 0.00
#> [5,] 1.0000000 0.0000000 0.00
#> [6,] 0.8861416 0.1138584 0.00
#> [7,] 0.8861416 0.1138584 0.00
#> [8,] 1.0000000 0.0000000 0.00
#> [9,] 1.0000000 0.0000000 0.00
#> [10,] 0.0000000 0.4000000 0.60
#> [11,] 0.8861416 0.1138584 0.00
#> [12,] 0.0000000 0.4400000 0.56
#> [13,] 0.0000000 0.1000000 0.90
#> [14,] 0.0000000 0.0000000 1.00
#> [15,] 0.0000000 1.0000000 0.00
#> [16,] 0.0000000 1.0000000 0.00
#> [17,] 0.0000000 1.0000000 0.00
#> [18,] 0.0000000 1.0000000 0.00
#> [19,] 0.8861416 0.1138584 0.00
#> [20,] 0.0000000 0.9000000 0.10
#> [21,] 0.0000000 0.0000000 1.00
#> [22,] 0.0000000 0.0000000 1.00
#> [23,] 0.0000000 0.0000000 1.00
#> [24,] 0.0000000 0.8000000 0.20
#> [25,] 0.0000000 0.1200000 0.88
#> [26,] 0.0000000 0.0000000 1.00
#> [27,] 0.0000000 0.0000000 1.00
#> [28,] 0.0000000 0.0000000 1.00
#> [29,] 0.0000000 0.0000000 1.00
#> [30,] 0.0000000 0.0000000 1.00