CustomClassifier.Rd
Descriptions of all parameters can be found in the nnetsauce documentation: https://techtonique.github.io/nnetsauce/
# Usage: `obj` is the base estimator to wrap (the example below passes a
# scikit-learn DecisionTreeClassifier); every other argument has a default.
# `backend` lists the accepted values; presumably the first one ("cpu") is
# used when the argument is omitted -- confirm against the implementation.
CustomClassifier(
  obj,
  n_hidden_features = 5L, activation_name = "relu", a = 0.01,
  nodes_sim = "sobol", bias = TRUE, dropout = 0, direct_link = TRUE,
  n_clusters = 2L, cluster_encode = TRUE, type_clust = "kmeans",
  col_sample = 1, row_sample = 1, seed = 123L,
  backend = c("cpu", "gpu", "tpu")
)
# Example data ----
# The four numeric iris measurements form the feature matrix; the species
# factor is recoded as 0-based integer class labels.
library(datasets)
set.seed(123)
y <- as.integer(iris[["Species"]]) - 1L
X <- as.matrix(iris[, 1:4])

# Train/test split ----
# Draw 80% of the row indices, without replacement, for training. The
# surrounding parentheses make R print the drawn indices.
(index_train <- base::sample.int(
  n = nrow(X),
  size = floor(nrow(X) * 0.8),
  replace = FALSE
))
#> [1] 14 50 118 43 150 148 90 91 143 92 137 99 72 26 7 78 81 147
#> [19] 103 117 76 32 106 109 136 9 41 74 23 27 60 53 126 119 121 96
#> [37] 38 89 34 93 69 138 130 63 13 82 97 142 25 114 21 79 124 47
#> [55] 144 120 16 6 127 86 132 39 31 134 149 112 4 128 110 102 52 22
#> [73] 129 87 35 40 30 12 88 123 64 146 67 122 37 8 51 10 115 42
#> [91] 44 85 107 139 73 20 46 17 54 108 75 80 71 15 24 68 133 145
#> [109] 29 104 45 140 101 135 95 116 5 111 94 49
# Rows in `index_train` go to the training set; the remaining rows (kept in
# their original order by negative indexing) form the test set.
X_train <- X[index_train, , drop = FALSE]
X_test <- X[-index_train, , drop = FALSE]
y_train <- y[index_train]
y_test <- y[-index_train]
# Model fitting ----
# Wrap a scikit-learn decision tree in nnetsauce's CustomClassifier.
obj <- sklearn$tree$DecisionTreeClassifier()
obj2 <- CustomClassifier(obj)

# reticulate converts a plain R vector of length > 1 into a Python *list*.
# The estimator needs an object with a `dtype`, so passing the raw vector fails
# with "AttributeError: 'list' object has no attribute 'dtype'". Wrapping the
# labels in as.array() makes reticulate hand over a NumPy ndarray instead.
# (X_train is already a matrix, which reticulate converts to an ndarray.)
obj2$fit(X_train, as.array(y_train))

# Score on the held-out rows. The labels are converted the same way as in
# fit(). Presumably this is sklearn-style mean accuracy -- confirm.
print(obj2$score(X_test, as.array(y_test)))