Parameter descriptions can be found at https://techtonique.github.io/nnetsauce/

MultitaskClassifier(
  obj,
  n_hidden_features = 5L,
  activation_name = "relu",
  a = 0.01,
  nodes_sim = "sobol",
  bias = TRUE,
  dropout = 0,
  direct_link = TRUE,
  n_clusters = 2L,
  cluster_encode = TRUE,
  type_clust = "kmeans",
  col_sample = 1,
  row_sample = 1,
  seed = 123L,
  backend = c("cpu", "gpu", "tpu")
)
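
The defaults shown above can be overridden at construction time. A minimal sketch, with argument names taken from the signature above and purely illustrative values (assuming "tanh" is among the supported activation names):

obj <- sklearn$linear_model$LinearRegression()
clf <- MultitaskClassifier(
  obj,
  n_hidden_features = 10L,  # wider hidden layer
  activation_name = "tanh", # alternative activation (illustrative)
  dropout = 0.1,            # drop 10% of hidden nodes at random
  seed = 42L                # for reproducibility
)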

Examples


library(datasets)

# predictors: the four numeric iris measurements
X <- as.matrix(iris[, 1:4])
# response: species recoded as integers 0, 1, 2
y <- as.integer(iris[, 5]) - 1L

# draw an 80% training sample without replacement
# (the outer parentheses print the sampled indices)
(index_train <- base::sample.int(n = nrow(X),
                                 size = floor(0.8 * nrow(X)),
                                 replace = FALSE))
#>   [1] 137 139  43 115  55  57 126  38  84  63  78  70 130 120  75  21  87  72
#>  [19]  59  81 146   6 128 111  28  32  49  99  92 127 142  65  69 136  47  20
#>  [37]   2  62 112 141 113  10 132 124   7  61 143  36  23  34  88   4  54  29
#>  [55]  58 134 117  56  51  50  48   3  33 101  66  64  40  96 147 140  25  71
#>  [73] 150 105  22  93 100  85  53  31  46  24  60  30 102  17  26 108 145 110
#>  [91] 148  86  73 103  45   8  15  77  94  89  39  74 104  52  83 144 149 119
#> [109] 121  90 109  41 138   1  16 114  14  12  44  91
# split predictors and response into training and test sets
X_train <- X[index_train, ]
y_train <- y[index_train]
X_test <- X[-index_train, ]
y_test <- y[-index_train]

# use scikit-learn's linear regression as the base learner
obj <- sklearn$linear_model$LinearRegression()
obj2 <- MultitaskClassifier(obj)
obj2$fit(X_train, y_train)
#> MultitaskClassifier(col_sample=1.0, dropout=0.0, obj=LinearRegression(),
#>                     row_sample=1.0)
# accuracy on the held-out test set
print(obj2$score(X_test, y_test))
#> [1] 1
# per-class predicted probabilities (one column per species)
print(obj2$predict_proba(X_test))
#>            [,1]      [,2]      [,3]
#>  [1,] 0.4223188 0.2966922 0.2809890
#>  [2,] 0.4223188 0.2828842 0.2947970
#>  [3,] 0.4223188 0.2996851 0.2779961
#>  [4,] 0.4223188 0.3008353 0.2768459
#>  [5,] 0.4223188 0.2863899 0.2912913
#>  [6,] 0.4223188 0.2842122 0.2934690
#>  [7,] 0.4223188 0.2634565 0.3142247
#>  [8,] 0.4223188 0.2866407 0.2910405
#>  [9,] 0.4223188 0.3088012 0.2688800
#> [10,] 0.4223188 0.2750425 0.3026387
#> [11,] 0.2869856 0.3855759 0.3274385
#> [12,] 0.2906424 0.4442556 0.2651020
#> [13,] 0.2882608 0.4134792 0.2982599
#> [14,] 0.2871063 0.3892762 0.3236174
#> [15,] 0.2910604 0.4486035 0.2603361
#> [16,] 0.2895694 0.4319646 0.2784659
#> [17,] 0.2881746 0.4120433 0.2997821
#> [18,] 0.2884411 0.4163694 0.2951895
#> [19,] 0.2885608 0.4182123 0.2932270
#> [20,] 0.2905598 0.2660690 0.4433712
#> [21,] 0.2867108 0.3388675 0.3744217
#> [22,] 0.2884160 0.2956082 0.4159758
#> [23,] 0.2899451 0.2735723 0.4364826
#> [24,] 0.2872885 0.3185449 0.3941666
#> [25,] 0.2904965 0.2668173 0.4426862
#> [26,] 0.2880809 0.3014815 0.4104376
#> [27,] 0.2888104 0.2893011 0.4218884
#> [28,] 0.2882207 0.2989644 0.4128149
#> [29,] 0.2894152 0.2805548 0.4300300
#> [30,] 0.2866530 0.3423594 0.3709876
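
Hard class labels can be obtained with the sklearn-style predict method; a minimal sketch, assuming predict is exposed alongside fit, score, and predict_proba as above (the cross-tabulation uses base R):

preds <- obj2$predict(X_test)
# confusion matrix of predicted vs. observed classes
table(predicted = preds, observed = y_test)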