A description of the parameters can be found at https://techtonique.github.io/nnetsauce/

Ridge2Classifier(
  n_hidden_features = 5L,
  activation_name = "relu",
  a = 0.01,
  nodes_sim = "sobol",
  bias = TRUE,
  dropout = 0,
  direct_link = TRUE,
  n_clusters = 2L,
  cluster_encode = TRUE,
  type_clust = "kmeans",
  lambda1 = 0.1,
  lambda2 = 0.1,
  seed = 123L,
  backend = c("cpu", "gpu", "tpu")
)
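
All hyperparameters are set at construction time. As a minimal sketch (the values below are chosen purely for illustration, not as recommended settings; the documentation linked above describes each argument):

# illustrative construction with non-default values (chosen arbitrarily here)
clf <- Ridge2Classifier(
  n_hidden_features = 10L,  # number of nodes in the hidden layer
  activation_name = "tanh", # hidden layer activation
  lambda1 = 0.05,           # first ridge regularization parameter
  lambda2 = 0.5,            # second ridge regularization parameter
  seed = 123L               # random seed, for reproducibility
)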

Examples


library(datasets)
library(nnetsauce)  # provides Ridge2Classifier()

# predictors and response from the iris dataset;
# the species are encoded as integers 0, 1, 2
X <- as.matrix(iris[, 1:4])
y <- as.integer(iris[, 5]) - 1L

# randomly select 80% of the observations for training
# (printed indices will vary from run to run, since no seed is set)
(index_train <- base::sample.int(n = nrow(X),
                                 size = floor(0.8*nrow(X)),
                                 replace = FALSE))
#>   [1]  99  35  26  13  36  28  44  54  29  93 147 116   5  56  77  92  72 126
#>  [19]  75  87  94  48  15  65 143 108  10  86  42 100  41  25 130  52  37  27
#>  [37]  73  83 149  34  63 131  76  79  49 112 142  71 141  45 111 114  47   1
#>  [55] 124 120  95  39 109   6  32  61  80  16   2 121  78  30 145  58 113 102
#>  [73]  70  68  67  46  59 129  96  82 125  97 144 106  33 146  51 140  50 128
#>  [91] 110 132  88  55  18 136  62 107 148 127  64   3 119  11  81 135  98  38
#> [109]  90  31 139 123  22  17 138  89 122 115  43 118
# train/test split
X_train <- X[index_train, ]
y_train <- y[index_train]
X_test <- X[-index_train, ]
y_test <- y[-index_train]

# fit a Ridge2Classifier with default hyperparameters
obj <- Ridge2Classifier()
obj$fit(X_train, y_train)
#> Ridge2Classifier(dropout=0.0)
# classification score on the held-out test set
print(obj$score(X_test, y_test))
#> [1] 0.9
# class membership probabilities on the training set
print(obj$predict_proba(X_train))
#>               [,1]        [,2]        [,3]
#>   [1,] 0.194300405 0.632274685 0.173424910
#>   [2,] 0.926674339 0.054878407 0.018447254
#>   [3,] 0.915425641 0.063162027 0.021412332
#>   [4,] 0.924952735 0.056503670 0.018543595
#>   [5,] 0.937013380 0.046862229 0.016124391
#>   [6,] 0.954094902 0.032184724 0.013720374
#>   [7,] 0.950698025 0.032549579 0.016752396
#>   [8,] 0.115181502 0.675426177 0.209392322
#>   [9,] 0.948382856 0.036852458 0.014764686
#>  [10,] 0.138702378 0.631188215 0.230109408
#>  [11,] 0.048930804 0.438148697 0.512920499
#>  [12,] 0.043122480 0.160500710 0.796376810
#>  [13,] 0.961984364 0.026507064 0.011508571
#>  [14,] 0.140846432 0.572598258 0.286555310
#>  [15,] 0.068277921 0.397015976 0.534706104
#>  [16,] 0.126741764 0.441530708 0.431727529
#>  [17,] 0.134110640 0.550160052 0.315729307
#>  [18,] 0.027181075 0.130682486 0.842136439
#>  [19,] 0.115927188 0.469689841 0.414382971
#>  [20,] 0.085125515 0.309902259 0.604972226
#>  [21,] 0.162752990 0.658482431 0.178764579
#>  [22,] 0.939407966 0.045533110 0.015058924
#>  [23,] 0.974722664 0.015737908 0.009539429
#>  [24,] 0.198310988 0.572354845 0.229334167
#>  [25,] 0.074580466 0.445922372 0.479497161
#>  [26,] 0.019350611 0.156906754 0.823742635
#>  [27,] 0.929371677 0.053105474 0.017522849
#>  [28,] 0.158096685 0.302768422 0.539134893
#>  [29,] 0.852434047 0.110186710 0.037379243
#>  [30,] 0.155266332 0.585283673 0.259449995
#>  [31,] 0.956270469 0.030582682 0.013146850
#>  [32,] 0.947886327 0.037384618 0.014729054
#>  [33,] 0.033247835 0.199658538 0.767093627
#>  [34,] 0.117823721 0.330041271 0.552135008
#>  [35,] 0.953060702 0.032649423 0.014289874
#>  [36,] 0.946281962 0.037323921 0.016394117
#>  [37,] 0.065182455 0.535873550 0.398943994
#>  [38,] 0.152578711 0.614074618 0.233346671
#>  [39,] 0.052954776 0.142225907 0.804819318
#>  [40,] 0.980914465 0.011404822 0.007680713
#>  [41,] 0.099425178 0.675606588 0.224968234
#>  [42,] 0.016754053 0.158065457 0.825180490
#>  [43,] 0.101451839 0.387319305 0.511228857
#>  [44,] 0.120114320 0.471128759 0.408756921
#>  [45,] 0.964221289 0.024145036 0.011633675
#>  [46,] 0.047563865 0.344928863 0.607507271
#>  [47,] 0.029666739 0.140128894 0.830204367
#>  [48,] 0.117104870 0.320644675 0.562250455
#>  [49,] 0.025472907 0.127676485 0.846850608
#>  [50,] 0.966095570 0.021643290 0.012261140
#>  [51,] 0.057828325 0.202331537 0.739840138
#>  [52,] 0.065040929 0.507038761 0.427920309
#>  [53,] 0.969371930 0.020439126 0.010188944
#>  [54,] 0.955685638 0.031271827 0.013042535
#>  [55,] 0.063047249 0.406295874 0.530656877
#>  [56,] 0.059703553 0.624508743 0.315787704
#>  [57,] 0.146010191 0.614642175 0.239347634
#>  [58,] 0.930227741 0.052532885 0.017239374
#>  [59,] 0.030569635 0.347275429 0.622154936
#>  [60,] 0.969527590 0.018860661 0.011611749
#>  [61,] 0.943104760 0.038944254 0.017950985
#>  [62,] 0.118362529 0.691810268 0.189827203
#>  [63,] 0.177278820 0.638772184 0.183948996
#>  [64,] 0.982010564 0.008967929 0.009021507
#>  [65,] 0.920306113 0.059667145 0.020026742
#>  [66,] 0.023656275 0.108071262 0.868272464
#>  [67,] 0.061024983 0.281483175 0.657491842
#>  [68,] 0.936287214 0.047530771 0.016182016
#>  [69,] 0.023538655 0.091161316 0.885300029
#>  [70,] 0.181953348 0.644126244 0.173920408
#>  [71,] 0.031906823 0.177413262 0.790679915
#>  [72,] 0.074580466 0.445922372 0.479497161
#>  [73,] 0.148205995 0.654349212 0.197444793
#>  [74,] 0.161655293 0.622730910 0.215613797
#>  [75,] 0.158755651 0.484860677 0.356383673
#>  [76,] 0.919147899 0.060310306 0.020541795
#>  [77,] 0.094734679 0.427775967 0.477489353
#>  [78,] 0.038013265 0.260455881 0.701530854
#>  [79,] 0.188758536 0.531790220 0.279451244
#>  [80,] 0.150539785 0.663754706 0.185705508
#>  [81,] 0.036113309 0.130855634 0.833031057
#>  [82,] 0.165814117 0.550044455 0.284141428
#>  [83,] 0.023617127 0.109185095 0.867197778
#>  [84,] 0.009739103 0.081986774 0.908274123
#>  [85,] 0.979997921 0.012820454 0.007181625
#>  [86,] 0.032211944 0.170500216 0.797287840
#>  [87,] 0.079492033 0.265614266 0.654893701
#>  [88,] 0.032012811 0.154789349 0.813197840
#>  [89,] 0.943088044 0.041571674 0.015340282
#>  [90,] 0.087888293 0.346327749 0.565783958
#>  [91,] 0.013584862 0.041583884 0.944831254
#>  [92,] 0.011564159 0.030809029 0.957626812
#>  [93,] 0.074597577 0.623630994 0.301771429
#>  [94,] 0.081666200 0.430212591 0.488121209
#>  [95,] 0.954693719 0.031475955 0.013830326
#>  [96,] 0.008942649 0.072670684 0.918386667
#>  [97,] 0.148865425 0.464275396 0.386859179
#>  [98,] 0.125451688 0.653177841 0.221370471
#>  [99,] 0.050296534 0.240490529 0.709212937
#> [100,] 0.074551683 0.396714659 0.528733658
#> [101,] 0.113942942 0.468434585 0.417622473
#> [102,] 0.939290243 0.045560404 0.015149354
#> [103,] 0.005388028 0.091583956 0.903028015
#> [104,] 0.963530445 0.024498751 0.011970804
#> [105,] 0.142322878 0.665238458 0.192438664
#> [106,] 0.068020767 0.514738567 0.417240665
#> [107,] 0.128442644 0.494156908 0.377400448
#> [108,] 0.963504133 0.025950348 0.010545519
#> [109,] 0.137605875 0.653208329 0.209185796
#> [110,] 0.926996114 0.054649388 0.018354498
#> [111,] 0.096858262 0.364967381 0.538174357
#> [112,] 0.008854332 0.102807968 0.888337700
#> [113,] 0.963874653 0.023766777 0.012358570
#> [114,] 0.971629084 0.017787402 0.010583514
#> [115,] 0.060797777 0.250412094 0.688790129
#> [116,] 0.192934454 0.528482492 0.278583054
#> [117,] 0.090220362 0.433082195 0.476697443
#> [118,] 0.055019418 0.308718972 0.636261610
#> [119,] 0.943242551 0.042766081 0.013991368
#> [120,] 0.009568034 0.026483761 0.963948205
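
Predicted class labels (encoded 0, 1, 2 as above) can be obtained with predict(), assuming the same scikit-learn-like interface as fit() and score() used above; a minimal sketch:

# predicted labels on the held-out set
preds <- obj$predict(X_test)

# hold-out accuracy computed by hand; should agree with obj$score(X_test, y_test)
mean(preds == y_test)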