Parameters' description can be found at https://techtonique.github.io/nnetsauce/

Ridge2MultitaskClassifier(
  n_hidden_features = 5L,
  activation_name = "relu",
  a = 0.01,
  nodes_sim = "sobol",
  bias = TRUE,
  dropout = 0,
  n_clusters = 2L,
  cluster_encode = TRUE,
  type_clust = "kmeans",
  lambda1 = 0.1,
  lambda2 = 0.1,
  seed = 123L,
  backend = c("cpu", "gpu", "tpu")
)

Examples


# Example 1 -----

library(datasets)

X <- as.matrix(iris[, 1:4])
y <- as.integer(iris[, 5]) - 1L

(index_train <- base::sample.int(n = nrow(X),
                                 size = floor(0.8*nrow(X)),
                                 replace = FALSE))
#>   [1]  89 143 123  35  30  69 130  23  21  19 137 104  32  48  15 147  55 125
#>  [19]  17  71  77  13  78  68  25  29  53  16  41 112   9  37  54  86 110 108
#>  [37] 117  24  14 146 109  62  40  95  39 101  72 121 140  65 149  26  46 100
#>  [55] 107  63 142  81 106 131  45  47  76  98  83  60 114 132 120  97  31  38
#>  [73]  33  85 139   5  96 122 119  87 148   6  52  91   8  43 134 127 113  75
#>  [91] 103 128   1  88 135   4  56  67  18  10  12 102   3  99  50  57  66  84
#> [109]  61 141  79 138   2 105  44  74 150 111  73  42
X_train <- X[index_train, ]
y_train <- y[index_train]
X_test <- X[-index_train, ]
y_test <- y[-index_train]

obj <- Ridge2MultitaskClassifier()
obj$fit(X_train, y_train)
#> Ridge2MultitaskClassifier(dropout=0.0)
print(obj$score(X_test, y_test))
#> [1] 1
print(obj$predict_proba(X_train))
#>             [,1]      [,2]      [,3]
#>   [1,] 0.2119028 0.5679791 0.2201180
#>   [2,] 0.2222580 0.2751043 0.5026377
#>   [3,] 0.1960123 0.1687949 0.6351928
#>   [4,] 0.5725322 0.2139629 0.2135049
#>   [5,] 0.5723656 0.1998084 0.2278259
#>   [6,] 0.2282339 0.4469197 0.3248464
#>   [7,] 0.2297084 0.3473757 0.4229159
#>   [8,] 0.5724206 0.2412066 0.1863728
#>   [9,] 0.5721475 0.2169334 0.2109191
#>  [10,] 0.5724972 0.2117283 0.2157745
#>  [11,] 0.2004248 0.1811702 0.6184050
#>  [12,] 0.2220180 0.2725159 0.5054661
#>  [13,] 0.5716951 0.1975446 0.2307603
#>  [14,] 0.5732507 0.2108187 0.2159307
#>  [15,] 0.5618873 0.2834096 0.1547031
#>  [16,] 0.2256459 0.3003175 0.4740366
#>  [17,] 0.2241065 0.4911452 0.2847483
#>  [18,] 0.2162416 0.2392536 0.5445047
#>  [19,] 0.5726677 0.2258854 0.2014470
#>  [20,] 0.2304335 0.3792600 0.3903065
#>  [21,] 0.2223732 0.5053236 0.2723032
#>  [22,] 0.5714159 0.2360257 0.1925583
#>  [23,] 0.2302895 0.4072438 0.3624667
#>  [24,] 0.1986058 0.6264954 0.1748988
#>  [25,] 0.5699465 0.1824273 0.2476262
#>  [26,] 0.5718675 0.2323467 0.1957858
#>  [27,] 0.2241075 0.4926812 0.2832113
#>  [28,] 0.5717220 0.2338188 0.1944593
#>  [29,] 0.5733807 0.2150763 0.2115430
#>  [30,] 0.2231156 0.2798937 0.4969907
#>  [31,] 0.5729196 0.2043697 0.2227107
#>  [32,] 0.5691438 0.2517189 0.1791374
#>  [33,] 0.2242887 0.4923884 0.2833229
#>  [34,] 0.2237639 0.4949676 0.2812685
#>  [35,] 0.1932464 0.1607491 0.6460045
#>  [36,] 0.2154702 0.2361702 0.5483596
#>  [37,] 0.2260205 0.3019460 0.4720335
#>  [38,] 0.5631233 0.1593461 0.2775306
#>  [39,] 0.5718459 0.2410642 0.1870899
#>  [40,] 0.2159335 0.2379768 0.5460897
#>  [41,] 0.2166355 0.2429980 0.5403665
#>  [42,] 0.2188693 0.5289117 0.2522190
#>  [43,] 0.5726500 0.2210817 0.2062682
#>  [44,] 0.2192879 0.5262520 0.2544601
#>  [45,] 0.5733923 0.2114978 0.2151098
#>  [46,] 0.1779933 0.1274569 0.6945498
#>  [47,] 0.2042964 0.6033677 0.1923359
#>  [48,] 0.2072531 0.2028786 0.5898683
#>  [49,] 0.2223745 0.2726852 0.5049403
#>  [50,] 0.2018468 0.6136139 0.1845393
#>  [51,] 0.2107552 0.2160620 0.5731828
#>  [52,] 0.5720786 0.2101994 0.2177220
#>  [53,] 0.5721635 0.2008058 0.2270307
#>  [54,] 0.2137379 0.5586346 0.2276275
#>  [55,] 0.2250271 0.2933798 0.4815931
#>  [56,] 0.1979689 0.6285376 0.1734935
#>  [57,] 0.2212705 0.2654090 0.5133205
#>  [58,] 0.2110894 0.5740879 0.2148227
#>  [59,] 0.1954895 0.1671904 0.6373201
#>  [60,] 0.2157353 0.2372684 0.5469964
#>  [61,] 0.5663306 0.1669411 0.2667283
#>  [62,] 0.5735611 0.2141180 0.2123209
#>  [63,] 0.2126713 0.5648437 0.2224850
#>  [64,] 0.2105998 0.5746184 0.2147818
#>  [65,] 0.2041034 0.6043031 0.1915935
#>  [66,] 0.2244610 0.4891008 0.2864383
#>  [67,] 0.2175224 0.2482361 0.5342415
#>  [68,] 0.2214559 0.2642088 0.5143352
#>  [69,] 0.2300238 0.3637993 0.4061770
#>  [70,] 0.2139885 0.5571489 0.2288626
#>  [71,] 0.5722843 0.2032758 0.2244399
#>  [72,] 0.5718442 0.2395466 0.1886092
#>  [73,] 0.5681514 0.2600805 0.1717681
#>  [74,] 0.2277734 0.4510310 0.3211956
#>  [75,] 0.2302664 0.3705014 0.3992322
#>  [76,] 0.5731031 0.2252229 0.2016739
#>  [77,] 0.2073103 0.5899337 0.2027560
#>  [78,] 0.2221278 0.2741013 0.5037709
#>  [79,] 0.1649823 0.1041407 0.7308771
#>  [80,] 0.2218397 0.5100532 0.2681071
#>  [81,] 0.2252013 0.2942849 0.4805138
#>  [82,] 0.5718953 0.1964048 0.2317000
#>  [83,] 0.2190436 0.5288448 0.2521116
#>  [84,] 0.2225403 0.5038442 0.2736155
#>  [85,] 0.5728942 0.2174884 0.2096175
#>  [86,] 0.5736865 0.2112509 0.2150626
#>  [87,] 0.2300596 0.4070733 0.3628671
#>  [88,] 0.2301785 0.3673792 0.4024423
#>  [89,] 0.2182543 0.2495078 0.5322379
#>  [90,] 0.2083976 0.5852233 0.2063791
#>  [91,] 0.2111244 0.2174724 0.5714032
#>  [92,] 0.2300826 0.3620010 0.4079165
#>  [93,] 0.5725135 0.2287685 0.1987180
#>  [94,] 0.2193321 0.5263701 0.2542978
#>  [95,] 0.2287820 0.3373035 0.4339145
#>  [96,] 0.5727483 0.2037028 0.2235489
#>  [97,] 0.2213263 0.5119962 0.2666776
#>  [98,] 0.2263354 0.4686155 0.3050491
#>  [99,] 0.5731331 0.2112669 0.2155999
#> [100,] 0.5717260 0.2317939 0.1964801
#> [101,] 0.5728176 0.2031882 0.2239942
#> [102,] 0.2222580 0.2751043 0.5026377
#> [103,] 0.5730339 0.2217224 0.2052437
#> [104,] 0.1988823 0.6283245 0.1727932
#> [105,] 0.5724913 0.2248716 0.2026371
#> [106,] 0.2259167 0.4758978 0.2981854
#> [107,] 0.2102356 0.5771984 0.2125660
#> [108,] 0.2297419 0.3574705 0.4127876
#> [109,] 0.2073400 0.5913190 0.2013410
#> [110,] 0.2006287 0.1817969 0.6175744
#> [111,] 0.2248077 0.4842701 0.2909222
#> [112,] 0.2261909 0.3033972 0.4704119
#> [113,] 0.5723338 0.2215386 0.2061277
#> [114,] 0.2014084 0.1845818 0.6140098
#> [115,] 0.5594281 0.1486497 0.2919222
#> [116,] 0.2183261 0.5322410 0.2494329
#> [117,] 0.2279285 0.3234957 0.4485757
#> [118,] 0.2280077 0.3208903 0.4511020
#> [119,] 0.2298329 0.4150179 0.3551492
#> [120,] 0.5715354 0.1987219 0.2297427