Ridge2Classifier.Rd
Parameter descriptions can be found at https://techtonique.github.io/nnetsauce/.
# Constructor signature, listing the default value of every argument.
# NOTE(review): `backend` lists the accepted values; presumably the first
# ("cpu") is used when the argument is not supplied -- confirm.
Ridge2Classifier(
  n_hidden_features = 5L,
  activation_name = "relu",
  a = 0.01,
  nodes_sim = "sobol",
  bias = TRUE,
  dropout = 0,
  direct_link = TRUE,
  n_clusters = 2L,
  cluster_encode = TRUE,
  type_clust = "kmeans",
  lambda1 = 0.1,
  lambda2 = 0.1,
  seed = 123L,
  backend = c("cpu", "gpu", "tpu")
)
# Example: multiclass classification of the built-in iris data.
library(datasets)
# Features: the first four columns of iris (the numeric measurements) as a matrix.
X <- as.matrix(iris[, 1:4])
# Labels: the Species column converted to its integer codes, then shifted
# down by one so that the class labels start at 0.
y <- as.integer(iris[, 5]) - 1L
# Sample 80% of the row indices (floor(0.8 * 150) = 120), without
# replacement, for training. The surrounding parentheses make R print the
# sampled indices (shown below).
# NOTE(review): no set.seed() is called before sample.int(), so this split --
# and every printed result that follows -- changes from run to run. Consider
# calling set.seed() here. Whether the `seed` argument of Ridge2Classifier()
# affects this R-side sampling is not visible here -- confirm.
(index_train <- base::sample.int(n = nrow(X),
size = floor(0.8*nrow(X)),
replace = FALSE))
#> [1] 99 35 26 13 36 28 44 54 29 93 147 116 5 56 77 92 72 126
#> [19] 75 87 94 48 15 65 143 108 10 86 42 100 41 25 130 52 37 27
#> [37] 73 83 149 34 63 131 76 79 49 112 142 71 141 45 111 114 47 1
#> [55] 124 120 95 39 109 6 32 61 80 16 2 121 78 30 145 58 113 102
#> [73] 70 68 67 46 59 129 96 82 125 97 144 106 33 146 51 140 50 128
#> [91] 110 132 88 55 18 136 62 107 148 127 64 3 119 11 81 135 98 38
#> [109] 90 31 139 123 22 17 138 89 122 115 43 118
# Rows listed in index_train form the training set; all remaining rows
# (selected via negative indexing) form the test set.
X_train <- X[index_train, ]
y_train <- y[index_train]
X_test <- X[-index_train, ]
y_test <- y[-index_train]
# Create a classifier with the default hyperparameters.
obj <- Ridge2Classifier()
# Fit on the training split. The call is left at top level, so its return
# value is auto-printed (shown below).
obj$fit(X_train, y_train)
#> Ridge2Classifier(dropout=0.0)
# Score the fitted model on the held-out test split.
# NOTE(review): presumably the score is classification accuracy -- confirm.
print(obj$score(X_test, y_test))
#> [1] 0.9
# Predicted class probabilities: one row per input observation and one column
# per class (three here). Each row sums to 1. The columns appear to be ordered
# by label 0, 1, 2 -- confirm.
# NOTE(review): probabilities are computed for the training data, whereas
# score() above uses the test data; X_test may be the more informative input.
print(obj$predict_proba(X_train))
#> [,1] [,2] [,3]
#> [1,] 0.194300405 0.632274685 0.173424910
#> [2,] 0.926674339 0.054878407 0.018447254
#> [3,] 0.915425641 0.063162027 0.021412332
#> [4,] 0.924952735 0.056503670 0.018543595
#> [5,] 0.937013380 0.046862229 0.016124391
#> [6,] 0.954094902 0.032184724 0.013720374
#> [7,] 0.950698025 0.032549579 0.016752396
#> [8,] 0.115181502 0.675426177 0.209392322
#> [9,] 0.948382856 0.036852458 0.014764686
#> [10,] 0.138702378 0.631188215 0.230109408
#> [11,] 0.048930804 0.438148697 0.512920499
#> [12,] 0.043122480 0.160500710 0.796376810
#> [13,] 0.961984364 0.026507064 0.011508571
#> [14,] 0.140846432 0.572598258 0.286555310
#> [15,] 0.068277921 0.397015976 0.534706104
#> [16,] 0.126741764 0.441530708 0.431727529
#> [17,] 0.134110640 0.550160052 0.315729307
#> [18,] 0.027181075 0.130682486 0.842136439
#> [19,] 0.115927188 0.469689841 0.414382971
#> [20,] 0.085125515 0.309902259 0.604972226
#> [21,] 0.162752990 0.658482431 0.178764579
#> [22,] 0.939407966 0.045533110 0.015058924
#> [23,] 0.974722664 0.015737908 0.009539429
#> [24,] 0.198310988 0.572354845 0.229334167
#> [25,] 0.074580466 0.445922372 0.479497161
#> [26,] 0.019350611 0.156906754 0.823742635
#> [27,] 0.929371677 0.053105474 0.017522849
#> [28,] 0.158096685 0.302768422 0.539134893
#> [29,] 0.852434047 0.110186710 0.037379243
#> [30,] 0.155266332 0.585283673 0.259449995
#> [31,] 0.956270469 0.030582682 0.013146850
#> [32,] 0.947886327 0.037384618 0.014729054
#> [33,] 0.033247835 0.199658538 0.767093627
#> [34,] 0.117823721 0.330041271 0.552135008
#> [35,] 0.953060702 0.032649423 0.014289874
#> [36,] 0.946281962 0.037323921 0.016394117
#> [37,] 0.065182455 0.535873550 0.398943994
#> [38,] 0.152578711 0.614074618 0.233346671
#> [39,] 0.052954776 0.142225907 0.804819318
#> [40,] 0.980914465 0.011404822 0.007680713
#> [41,] 0.099425178 0.675606588 0.224968234
#> [42,] 0.016754053 0.158065457 0.825180490
#> [43,] 0.101451839 0.387319305 0.511228857
#> [44,] 0.120114320 0.471128759 0.408756921
#> [45,] 0.964221289 0.024145036 0.011633675
#> [46,] 0.047563865 0.344928863 0.607507271
#> [47,] 0.029666739 0.140128894 0.830204367
#> [48,] 0.117104870 0.320644675 0.562250455
#> [49,] 0.025472907 0.127676485 0.846850608
#> [50,] 0.966095570 0.021643290 0.012261140
#> [51,] 0.057828325 0.202331537 0.739840138
#> [52,] 0.065040929 0.507038761 0.427920309
#> [53,] 0.969371930 0.020439126 0.010188944
#> [54,] 0.955685638 0.031271827 0.013042535
#> [55,] 0.063047249 0.406295874 0.530656877
#> [56,] 0.059703553 0.624508743 0.315787704
#> [57,] 0.146010191 0.614642175 0.239347634
#> [58,] 0.930227741 0.052532885 0.017239374
#> [59,] 0.030569635 0.347275429 0.622154936
#> [60,] 0.969527590 0.018860661 0.011611749
#> [61,] 0.943104760 0.038944254 0.017950985
#> [62,] 0.118362529 0.691810268 0.189827203
#> [63,] 0.177278820 0.638772184 0.183948996
#> [64,] 0.982010564 0.008967929 0.009021507
#> [65,] 0.920306113 0.059667145 0.020026742
#> [66,] 0.023656275 0.108071262 0.868272464
#> [67,] 0.061024983 0.281483175 0.657491842
#> [68,] 0.936287214 0.047530771 0.016182016
#> [69,] 0.023538655 0.091161316 0.885300029
#> [70,] 0.181953348 0.644126244 0.173920408
#> [71,] 0.031906823 0.177413262 0.790679915
#> [72,] 0.074580466 0.445922372 0.479497161
#> [73,] 0.148205995 0.654349212 0.197444793
#> [74,] 0.161655293 0.622730910 0.215613797
#> [75,] 0.158755651 0.484860677 0.356383673
#> [76,] 0.919147899 0.060310306 0.020541795
#> [77,] 0.094734679 0.427775967 0.477489353
#> [78,] 0.038013265 0.260455881 0.701530854
#> [79,] 0.188758536 0.531790220 0.279451244
#> [80,] 0.150539785 0.663754706 0.185705508
#> [81,] 0.036113309 0.130855634 0.833031057
#> [82,] 0.165814117 0.550044455 0.284141428
#> [83,] 0.023617127 0.109185095 0.867197778
#> [84,] 0.009739103 0.081986774 0.908274123
#> [85,] 0.979997921 0.012820454 0.007181625
#> [86,] 0.032211944 0.170500216 0.797287840
#> [87,] 0.079492033 0.265614266 0.654893701
#> [88,] 0.032012811 0.154789349 0.813197840
#> [89,] 0.943088044 0.041571674 0.015340282
#> [90,] 0.087888293 0.346327749 0.565783958
#> [91,] 0.013584862 0.041583884 0.944831254
#> [92,] 0.011564159 0.030809029 0.957626812
#> [93,] 0.074597577 0.623630994 0.301771429
#> [94,] 0.081666200 0.430212591 0.488121209
#> [95,] 0.954693719 0.031475955 0.013830326
#> [96,] 0.008942649 0.072670684 0.918386667
#> [97,] 0.148865425 0.464275396 0.386859179
#> [98,] 0.125451688 0.653177841 0.221370471
#> [99,] 0.050296534 0.240490529 0.709212937
#> [100,] 0.074551683 0.396714659 0.528733658
#> [101,] 0.113942942 0.468434585 0.417622473
#> [102,] 0.939290243 0.045560404 0.015149354
#> [103,] 0.005388028 0.091583956 0.903028015
#> [104,] 0.963530445 0.024498751 0.011970804
#> [105,] 0.142322878 0.665238458 0.192438664
#> [106,] 0.068020767 0.514738567 0.417240665
#> [107,] 0.128442644 0.494156908 0.377400448
#> [108,] 0.963504133 0.025950348 0.010545519
#> [109,] 0.137605875 0.653208329 0.209185796
#> [110,] 0.926996114 0.054649388 0.018354498
#> [111,] 0.096858262 0.364967381 0.538174357
#> [112,] 0.008854332 0.102807968 0.888337700
#> [113,] 0.963874653 0.023766777 0.012358570
#> [114,] 0.971629084 0.017787402 0.010583514
#> [115,] 0.060797777 0.250412094 0.688790129
#> [116,] 0.192934454 0.528482492 0.278583054
#> [117,] 0.090220362 0.433082195 0.476697443
#> [118,] 0.055019418 0.308718972 0.636261610
#> [119,] 0.943242551 0.042766081 0.013991368
#> [120,] 0.009568034 0.026483761 0.963948205