Ridge2MultitaskClassifier.Rd
Parameters' description can be found at https://techtonique.github.io/nnetsauce/
Ridge2MultitaskClassifier(
n_hidden_features = 5L,
activation_name = "relu",
a = 0.01,
nodes_sim = "sobol",
bias = TRUE,
dropout = 0,
n_clusters = 2L,
cluster_encode = TRUE,
type_clust = "kmeans",
lambda1 = 0.1,
lambda2 = 0.1,
seed = 123L,
backend = c("cpu", "gpu", "tpu")
)
# Example 1 -----
library(datasets)
X <- as.matrix(iris[, 1:4])
y <- as.integer(iris[, 5]) - 1L
(index_train <- base::sample.int(n = nrow(X),
size = floor(0.8*nrow(X)),
replace = FALSE))
#> [1] 89 143 123 35 30 69 130 23 21 19 137 104 32 48 15 147 55 125
#> [19] 17 71 77 13 78 68 25 29 53 16 41 112 9 37 54 86 110 108
#> [37] 117 24 14 146 109 62 40 95 39 101 72 121 140 65 149 26 46 100
#> [55] 107 63 142 81 106 131 45 47 76 98 83 60 114 132 120 97 31 38
#> [73] 33 85 139 5 96 122 119 87 148 6 52 91 8 43 134 127 113 75
#> [91] 103 128 1 88 135 4 56 67 18 10 12 102 3 99 50 57 66 84
#> [109] 61 141 79 138 2 105 44 74 150 111 73 42
X_train <- X[index_train, ]
y_train <- y[index_train]
X_test <- X[-index_train, ]
y_test <- y[-index_train]
obj <- Ridge2MultitaskClassifier()
obj$fit(X_train, y_train)
#> Ridge2MultitaskClassifier(dropout=0.0)
print(obj$score(X_test, y_test))
#> [1] 1
print(obj$predict_proba(X_train))
#> [,1] [,2] [,3]
#> [1,] 0.2119028 0.5679791 0.2201180
#> [2,] 0.2222580 0.2751043 0.5026377
#> [3,] 0.1960123 0.1687949 0.6351928
#> [4,] 0.5725322 0.2139629 0.2135049
#> [5,] 0.5723656 0.1998084 0.2278259
#> [6,] 0.2282339 0.4469197 0.3248464
#> [7,] 0.2297084 0.3473757 0.4229159
#> [8,] 0.5724206 0.2412066 0.1863728
#> [9,] 0.5721475 0.2169334 0.2109191
#> [10,] 0.5724972 0.2117283 0.2157745
#> [11,] 0.2004248 0.1811702 0.6184050
#> [12,] 0.2220180 0.2725159 0.5054661
#> [13,] 0.5716951 0.1975446 0.2307603
#> [14,] 0.5732507 0.2108187 0.2159307
#> [15,] 0.5618873 0.2834096 0.1547031
#> [16,] 0.2256459 0.3003175 0.4740366
#> [17,] 0.2241065 0.4911452 0.2847483
#> [18,] 0.2162416 0.2392536 0.5445047
#> [19,] 0.5726677 0.2258854 0.2014470
#> [20,] 0.2304335 0.3792600 0.3903065
#> [21,] 0.2223732 0.5053236 0.2723032
#> [22,] 0.5714159 0.2360257 0.1925583
#> [23,] 0.2302895 0.4072438 0.3624667
#> [24,] 0.1986058 0.6264954 0.1748988
#> [25,] 0.5699465 0.1824273 0.2476262
#> [26,] 0.5718675 0.2323467 0.1957858
#> [27,] 0.2241075 0.4926812 0.2832113
#> [28,] 0.5717220 0.2338188 0.1944593
#> [29,] 0.5733807 0.2150763 0.2115430
#> [30,] 0.2231156 0.2798937 0.4969907
#> [31,] 0.5729196 0.2043697 0.2227107
#> [32,] 0.5691438 0.2517189 0.1791374
#> [33,] 0.2242887 0.4923884 0.2833229
#> [34,] 0.2237639 0.4949676 0.2812685
#> [35,] 0.1932464 0.1607491 0.6460045
#> [36,] 0.2154702 0.2361702 0.5483596
#> [37,] 0.2260205 0.3019460 0.4720335
#> [38,] 0.5631233 0.1593461 0.2775306
#> [39,] 0.5718459 0.2410642 0.1870899
#> [40,] 0.2159335 0.2379768 0.5460897
#> [41,] 0.2166355 0.2429980 0.5403665
#> [42,] 0.2188693 0.5289117 0.2522190
#> [43,] 0.5726500 0.2210817 0.2062682
#> [44,] 0.2192879 0.5262520 0.2544601
#> [45,] 0.5733923 0.2114978 0.2151098
#> [46,] 0.1779933 0.1274569 0.6945498
#> [47,] 0.2042964 0.6033677 0.1923359
#> [48,] 0.2072531 0.2028786 0.5898683
#> [49,] 0.2223745 0.2726852 0.5049403
#> [50,] 0.2018468 0.6136139 0.1845393
#> [51,] 0.2107552 0.2160620 0.5731828
#> [52,] 0.5720786 0.2101994 0.2177220
#> [53,] 0.5721635 0.2008058 0.2270307
#> [54,] 0.2137379 0.5586346 0.2276275
#> [55,] 0.2250271 0.2933798 0.4815931
#> [56,] 0.1979689 0.6285376 0.1734935
#> [57,] 0.2212705 0.2654090 0.5133205
#> [58,] 0.2110894 0.5740879 0.2148227
#> [59,] 0.1954895 0.1671904 0.6373201
#> [60,] 0.2157353 0.2372684 0.5469964
#> [61,] 0.5663306 0.1669411 0.2667283
#> [62,] 0.5735611 0.2141180 0.2123209
#> [63,] 0.2126713 0.5648437 0.2224850
#> [64,] 0.2105998 0.5746184 0.2147818
#> [65,] 0.2041034 0.6043031 0.1915935
#> [66,] 0.2244610 0.4891008 0.2864383
#> [67,] 0.2175224 0.2482361 0.5342415
#> [68,] 0.2214559 0.2642088 0.5143352
#> [69,] 0.2300238 0.3637993 0.4061770
#> [70,] 0.2139885 0.5571489 0.2288626
#> [71,] 0.5722843 0.2032758 0.2244399
#> [72,] 0.5718442 0.2395466 0.1886092
#> [73,] 0.5681514 0.2600805 0.1717681
#> [74,] 0.2277734 0.4510310 0.3211956
#> [75,] 0.2302664 0.3705014 0.3992322
#> [76,] 0.5731031 0.2252229 0.2016739
#> [77,] 0.2073103 0.5899337 0.2027560
#> [78,] 0.2221278 0.2741013 0.5037709
#> [79,] 0.1649823 0.1041407 0.7308771
#> [80,] 0.2218397 0.5100532 0.2681071
#> [81,] 0.2252013 0.2942849 0.4805138
#> [82,] 0.5718953 0.1964048 0.2317000
#> [83,] 0.2190436 0.5288448 0.2521116
#> [84,] 0.2225403 0.5038442 0.2736155
#> [85,] 0.5728942 0.2174884 0.2096175
#> [86,] 0.5736865 0.2112509 0.2150626
#> [87,] 0.2300596 0.4070733 0.3628671
#> [88,] 0.2301785 0.3673792 0.4024423
#> [89,] 0.2182543 0.2495078 0.5322379
#> [90,] 0.2083976 0.5852233 0.2063791
#> [91,] 0.2111244 0.2174724 0.5714032
#> [92,] 0.2300826 0.3620010 0.4079165
#> [93,] 0.5725135 0.2287685 0.1987180
#> [94,] 0.2193321 0.5263701 0.2542978
#> [95,] 0.2287820 0.3373035 0.4339145
#> [96,] 0.5727483 0.2037028 0.2235489
#> [97,] 0.2213263 0.5119962 0.2666776
#> [98,] 0.2263354 0.4686155 0.3050491
#> [99,] 0.5731331 0.2112669 0.2155999
#> [100,] 0.5717260 0.2317939 0.1964801
#> [101,] 0.5728176 0.2031882 0.2239942
#> [102,] 0.2222580 0.2751043 0.5026377
#> [103,] 0.5730339 0.2217224 0.2052437
#> [104,] 0.1988823 0.6283245 0.1727932
#> [105,] 0.5724913 0.2248716 0.2026371
#> [106,] 0.2259167 0.4758978 0.2981854
#> [107,] 0.2102356 0.5771984 0.2125660
#> [108,] 0.2297419 0.3574705 0.4127876
#> [109,] 0.2073400 0.5913190 0.2013410
#> [110,] 0.2006287 0.1817969 0.6175744
#> [111,] 0.2248077 0.4842701 0.2909222
#> [112,] 0.2261909 0.3033972 0.4704119
#> [113,] 0.5723338 0.2215386 0.2061277
#> [114,] 0.2014084 0.1845818 0.6140098
#> [115,] 0.5594281 0.1486497 0.2919222
#> [116,] 0.2183261 0.5322410 0.2494329
#> [117,] 0.2279285 0.3234957 0.4485757
#> [118,] 0.2280077 0.3208903 0.4511020
#> [119,] 0.2298329 0.4150179 0.3551492
#> [120,] 0.5715354 0.1987219 0.2297427