# 1 - With matrix inversion
## Loading required package: cclust
## Loading required package: memoise
# 0 - dataset -----
# Predictors: every mtcars column except the response `mpg` (column 1).
X <- as.matrix(mtcars[, -1])
y <- mtcars$mpg
(n <- nrow(X)); (p <- ncol(X))
## [1] 32
## [1] 10

# 1 - train / validation / test split -----
# NOTE(review): the statement that originally created `idx_train` is missing
# from this document; only its printed value (the output line below)
# survived. The indices are reproduced from that output so that the rest of
# the script runs -- replace with the original sampling code if recovered.
(idx_train <- c(1L, 2L, 6L, 7L, 8L, 9L, 10L, 12L, 13L, 14L, 16L, 18L, 19L,
                21L, 22L, 23L, 25L, 26L, 27L, 30L, 31L))
## [1] 1 2 6 7 8 9 10 12 13 14 16 18 19 21 22 23 25 26 27 30 31
# Rows not used for training; they are split at random into two halves.
(idx_comp <- setdiff(seq_len(n), idx_train))
## [1] 3 4 5 11 15 17 20 24 28 29 32
# Draw half of the remaining rows for validation. Indexing with
# sample.int() avoids sample()'s special case when `idx_comp` has length 1.
# No seed is set here, so this split (and every number printed afterwards)
# varies from run to run.
(idx_validation <- idx_comp[sample.int(length(idx_comp),
                                       size = floor(length(idx_comp) / 2),
                                       replace = FALSE)])
## [1] 15 28 29 17 32
(idx_test <- setdiff(idx_comp, idx_validation))
## [1] 3 4 5 11 20 24
# Sanity checks: the three index sets are pairwise disjoint and cover 1..n.
stopifnot(
  length(intersect(idx_train, idx_validation)) == 0L,
  length(intersect(idx_train, idx_test)) == 0L,
  length(intersect(idx_validation, idx_test)) == 0L,
  setequal(c(idx_train, idx_validation, idx_test), seq_len(n))
)
# drop = FALSE keeps a one-row subset a matrix rather than a vector.
X_train <- X[idx_train, , drop = FALSE]
y_train <- y[idx_train]
X_validation <- X[idx_validation, , drop = FALSE]
y_validation <- y[idx_validation]
X_test <- X[idx_test, , drop = FALSE]
y_test <- y[idx_test]
# 2 - Fit the model -----
# First pass: fit_rvfl with its default settings; its result carries a GCV
# value for each candidate `lambda`, and the minimizer is kept.
# (This is only one way to choose the regularization parameter, and lambda is
# not the only tuning parameter -- type 'head(bayesianrvfl::fit_rvfl)'.)
obj_GCV <- bayesianrvfl::fit_rvfl(x = X_train, y = y_train)
best_lambda <- obj_GCV$lambda[which.min(obj_GCV$GCV)]
print(best_lambda)
## [1] 12.9155
# Second pass: refit with that single lambda using method = "solve", and ask
# for the Sigma matrix (compute_Sigma = TRUE; see `$Sigma` in the printout).
fit_obj <- bayesianrvfl::fit_rvfl(
  x = X_train, y = y_train,
  method = "solve", lambda = best_lambda,
  compute_Sigma = TRUE
)
print(fit_obj)
## $coef
## [,1]
## cyl -0.36321389
## disp -0.41963495
## hp -0.47508218
## drat 0.55766798
## wt -0.54482177
## qsec 0.09679212
## vs 0.32281808
## am 0.48049052
## gear 0.25888698
## carb -0.47461167
## h1 -0.32480888
## h2 0.75205757
## h3 -0.35831420
## h4 0.04812892
## h5 0.72081638
##
## $Dn
## cyl disp hp drat wt
## cyl 0.0641456316 -0.0052760715 -0.0062410222 0.005822304 -0.0024686403
## disp -0.0052760715 0.0622811372 -0.0074713027 0.004016479 -0.0091846066
## hp -0.0062410222 -0.0074713027 0.0568135298 0.003192815 -0.0026470260
## drat 0.0058223040 0.0040164788 0.0031928154 0.055097079 0.0055691773
## wt -0.0024686403 -0.0091846066 -0.0026470260 0.005569177 0.0589471342
## qsec 0.0075639422 0.0025913924 0.0063259120 0.001488753 -0.0064484723
## vs 0.0079914361 0.0050643235 0.0028618316 -0.002787456 -0.0001204504
## am -0.0007384843 0.0028702003 -0.0004960659 -0.004987713 0.0066912080
## gear 0.0061324795 0.0041541267 -0.0044793662 -0.006225893 0.0002478994
## carb -0.0038594492 -0.0006591156 -0.0125736275 -0.004410981 -0.0047021678
## h1 -0.0078853168 -0.0071920042 -0.0117347358 0.002126297 -0.0056497049
## h2 0.0064432462 0.0008427470 0.0054354367 -0.012518603 0.0057446864
## h3 -0.0048297412 -0.0080197306 -0.0029111978 0.003526704 -0.0123647411
## h4 -0.0074352793 -0.0102001681 0.0032640334 0.004292136 -0.0019657223
## h5 0.0052613190 0.0025032477 0.0009585112 -0.002262603 0.0075715755
## qsec vs am gear carb
## cyl 0.0075639422 0.0079914361 -0.0007384843 0.0061324795 -0.0038594492
## disp 0.0025913924 0.0050643235 0.0028702003 0.0041541267 -0.0006591156
## hp 0.0063259120 0.0028618316 -0.0004960659 -0.0044793662 -0.0125736275
## drat 0.0014887532 -0.0027874565 -0.0049877134 -0.0062258929 -0.0044109811
## wt -0.0064484723 -0.0001204504 0.0066912080 0.0002478994 -0.0047021678
## qsec 0.0503112593 -0.0133213299 0.0114424676 0.0037203340 0.0079741656
## vs -0.0133213299 0.0509981024 0.0064243419 -0.0002187248 0.0012511732
## am 0.0114424676 0.0064243419 0.0532933538 -0.0095830826 -0.0064386072
## gear 0.0037203340 -0.0002187248 -0.0095830826 0.0511369322 -0.0127870791
## carb 0.0079741656 0.0012511732 -0.0064386072 -0.0127870791 0.0526739148
## h1 0.0041239477 -0.0003297649 0.0046260028 0.0046215565 -0.0032832747
## h2 -0.0007966727 0.0035853205 -0.0083514300 -0.0056704719 0.0087822325
## h3 -0.0013043556 0.0049873217 0.0012642085 0.0067785183 0.0010486529
## h4 0.0044846921 0.0012511349 0.0071472958 0.0051691971 0.0091636185
## h5 -0.0039500626 -0.0125143313 -0.0078167247 0.0043730045 0.0088466249
## h1 h2 h3 h4 h5
## cyl -7.885317e-03 6.443246e-03 -0.004829741 -7.435279e-03 0.0052613190
## disp -7.192004e-03 8.427470e-04 -0.008019731 -1.020017e-02 0.0025032477
## hp -1.173474e-02 5.435437e-03 -0.002911198 3.264033e-03 0.0009585112
## drat 2.126297e-03 -1.251860e-02 0.003526704 4.292136e-03 -0.0022626032
## wt -5.649705e-03 5.744686e-03 -0.012364741 -1.965722e-03 0.0075715755
## qsec 4.123948e-03 -7.966727e-04 -0.001304356 4.484692e-03 -0.0039500626
## vs -3.297649e-04 3.585321e-03 0.004987322 1.251135e-03 -0.0125143313
## am 4.626003e-03 -8.351430e-03 0.001264209 7.147296e-03 -0.0078167247
## gear 4.621557e-03 -5.670472e-03 0.006778518 5.169197e-03 0.0043730045
## carb -3.283275e-03 8.782232e-03 0.001048653 9.163618e-03 0.0088466249
## h1 5.830277e-02 8.513149e-05 -0.012269418 -9.946594e-05 -0.0014691719
## h2 8.513149e-05 5.490943e-02 -0.002961700 2.567923e-03 -0.0080429381
## h3 -1.226942e-02 -2.961700e-03 0.053107715 3.521041e-03 -0.0026247200
## h4 -9.946594e-05 2.567923e-03 0.003521041 4.677769e-02 0.0036453588
## h5 -1.469172e-03 -8.042938e-03 -0.002624720 3.645359e-03 0.0532094367
##
## $Sigma
## cyl disp hp drat wt
## cyl 0.828472689 -0.068143084 -0.080605901 0.07519795 -0.031883715
## disp -0.068143084 0.804391819 -0.096495585 0.05187482 -0.118623756
## hp -0.080605901 -0.096495585 0.733774954 0.04123680 -0.034187655
## drat 0.075197948 0.051874818 0.041236797 0.71160615 0.071928691
## wt -0.031883715 -0.118623756 -0.034187655 0.07192869 0.761331514
## qsec 0.097692070 0.033469120 0.081702295 0.01922799 -0.083285222
## vs 0.103213366 0.065408254 0.036961976 -0.03600138 -0.001555677
## am -0.009537891 0.037070063 -0.006406938 -0.06441880 0.086420275
## gear 0.079204018 0.053652609 -0.057853240 -0.08041050 0.003201744
## carb -0.049846703 -0.008512805 -0.162394644 -0.05697001 -0.060730832
## h1 -0.101842782 -0.092888307 -0.151559941 0.02746218 -0.072968745
## h2 0.083217724 0.010884496 0.070201364 -0.16168398 0.074195478
## h3 -0.062378506 -0.103578804 -0.037599565 0.04554913 -0.159696772
## h4 -0.096030324 -0.131740236 0.042156613 0.05543506 -0.025388280
## h5 0.067952548 0.032330688 0.012379648 -0.02922264 0.097790658
## qsec vs am gear carb
## cyl 0.09769207 0.103213366 -0.009537891 0.079204018 -0.049846703
## disp 0.03346912 0.065408254 0.037070063 0.053652609 -0.008512805
## hp 0.08170229 0.036961976 -0.006406938 -0.057853240 -0.162394644
## drat 0.01922799 -0.036001385 -0.064418796 -0.080410499 -0.056970012
## wt -0.08328522 -0.001555677 0.086420275 0.003201744 -0.060730832
## qsec 0.64979490 -0.172051591 0.147785152 0.048049961 0.102990309
## vs -0.17205159 0.658665821 0.082973567 -0.002824940 0.016159523
## am 0.14778515 0.082973567 0.688310132 -0.123770271 -0.083157809
## gear 0.04804996 -0.002824940 -0.123770271 0.660458876 -0.165151477
## carb 0.10299031 0.016159523 -0.083157809 -0.165151477 0.680309770
## h1 0.05326283 -0.004259078 0.059747124 0.059689698 -0.042405124
## h2 -0.01028942 0.046306195 -0.107862866 -0.073236961 0.113426894
## h3 -0.01684640 0.064413736 0.016327881 0.087547931 0.013543873
## h4 0.05792203 0.016159029 0.092310874 0.066762747 0.118352683
## h5 -0.05101702 -0.161628804 -0.100956881 0.056479525 0.114258554
## h1 h2 h3 h4 h5
## cyl -0.101842782 0.083217724 -0.06237851 -0.096030324 0.06795255
## disp -0.092888307 0.010884496 -0.10357880 -0.131740236 0.03233069
## hp -0.151559941 0.070201364 -0.03759957 0.042156613 0.01237965
## drat 0.027462183 -0.161683977 0.04554913 0.055435063 -0.02922264
## wt -0.072968745 0.074195478 -0.15969677 -0.025388280 0.09779066
## qsec 0.053262832 -0.010289424 -0.01684640 0.057922025 -0.05101702
## vs -0.004259078 0.046306195 0.06441374 0.016159029 -0.16162880
## am 0.059747124 -0.107862866 0.01632788 0.092310874 -0.10095688
## gear 0.059689698 -0.073236961 0.08754793 0.066762747 0.05647952
## carb -0.042405124 0.113426894 0.01354387 0.118352683 0.11425855
## h1 0.753009224 0.001099515 -0.15846563 -0.001284652 -0.01897509
## h2 0.001099515 0.709182620 -0.03825183 0.033166004 -0.10387854
## h3 -0.158465631 -0.038251826 0.68591252 0.045475988 -0.03389956
## h4 -0.001284652 0.033166004 0.04547599 0.604157095 0.04708162
## h5 -0.018975085 -0.103878541 -0.03389956 0.047081619 0.68722630
##
## $scales
## cyl disp hp drat wt qsec
## 1.7353207 109.0975120 67.7779822 0.5580457 0.8364141 1.8376898
## vs am gear carb h1 h2
## 0.4856209 0.4856209 0.7126966 1.7156080 0.9667548 1.2506310
## h3 h4 h5
## 1.0024913 0.5009253 1.5452324
##
## $lambda
## [1] 12.9155
##
## $ym
## [1] 20.10952
##
## $xm
## cyl disp hp drat wt qsec
## 6.1904762 218.5380952 141.0476190 3.5657143 3.1811429 17.9276190
## vs am gear carb h1 h2
## 0.3809524 0.3809524 3.6666667 2.9047619 0.9011436 0.9011436
## h3 h4 h5
## 0.7713531 0.4177551 1.0757344
##
## $n_clusters
## [1] 0
##
## $clusters_scales
## $clusters_scales$means
## cyl disp hp drat wt qsec
## 6.1904762 218.5380952 141.0476190 3.5657143 3.1811429 17.9276190
## vs am gear carb
## 0.3809524 0.3809524 3.6666667 2.9047619
##
## $clusters_scales$sds
## cyl disp hp drat wt qsec
## 1.7353207 109.0975120 67.7779822 0.5580457 0.8364141 1.8376898
## vs am gear carb
## 0.4856209 0.4856209 0.7126966 1.7156080
##
##
## $clust_obj
## NULL
##
## $nb_hidden
## [1] 5
##
## $nodes_sim
## [1] "sobol"
##
## $activ
## [1] "relu"
##
## $nn_xm
## cyl disp hp drat wt qsec
## 6.1904762 218.5380952 141.0476190 3.5657143 3.1811429 17.9276190
## vs am gear carb
## 0.3809524 0.3809524 3.6666667 2.9047619
##
## $nn_scales
## cyl disp hp drat wt qsec
## 1.7353207 109.0975120 67.7779822 0.5580457 0.8364141 1.8376898
## vs am gear carb
## 0.4856209 0.4856209 0.7126966 1.7156080
##
## $fitted_values
## Mazda RX4 Mazda RX4 Wag Valiant Duster 360
## 21.79908 21.59131 19.67633 14.78088
## Merc 240D Merc 230 Merc 280 Merc 450SE
## 22.79888 23.52919 19.78664 15.28762
## Merc 450SL Merc 450SLC Lincoln Continental Fiat 128
## 15.69314 15.62712 12.41438 27.77353
## Honda Civic Toyota Corona Dodge Challenger AMC Javelin
## 29.01498 23.50423 16.25692 17.07550
## Pontiac Firebird Fiat X1-9 Porsche 914-2 Ferrari Dino
## 16.08283 27.99151 26.24130 19.96917
## Maserati Bora
## 15.40547
##
## $compute_Sigma
## [1] TRUE
##
## $x
## cyl disp hp drat wt qsec vs am gear carb
## Mazda RX4 6 160.0 110 3.90 2.620 16.46 0 1 4 4
## Mazda RX4 Wag 6 160.0 110 3.90 2.875 17.02 0 1 4 4
## Valiant 6 225.0 105 2.76 3.460 20.22 1 0 3 1
## Duster 360 8 360.0 245 3.21 3.570 15.84 0 0 3 4
## Merc 240D 4 146.7 62 3.69 3.190 20.00 1 0 4 2
## Merc 230 4 140.8 95 3.92 3.150 22.90 1 0 4 2
## Merc 280 6 167.6 123 3.92 3.440 18.30 1 0 4 4
## Merc 450SE 8 275.8 180 3.07 4.070 17.40 0 0 3 3
## Merc 450SL 8 275.8 180 3.07 3.730 17.60 0 0 3 3
## Merc 450SLC 8 275.8 180 3.07 3.780 18.00 0 0 3 3
## Lincoln Continental 8 460.0 215 3.00 5.424 17.82 0 0 3 4
## Fiat 128 4 78.7 66 4.08 2.200 19.47 1 1 4 1
## Honda Civic 4 75.7 52 4.93 1.615 18.52 1 1 4 2
## Toyota Corona 4 120.1 97 3.70 2.465 20.01 1 0 3 1
## Dodge Challenger 8 318.0 150 2.76 3.520 16.87 0 0 3 2
## AMC Javelin 8 304.0 150 3.15 3.435 17.30 0 0 3 2
## Pontiac Firebird 8 400.0 175 3.08 3.845 17.05 0 0 3 2
## Fiat X1-9 4 79.0 66 4.08 1.935 18.90 1 1 4 1
## Porsche 914-2 4 120.3 91 4.43 2.140 16.70 0 1 5 2
## Ferrari Dino 6 145.0 175 3.62 2.770 15.50 0 1 5 6
## Maserati Bora 8 301.0 335 3.54 3.570 14.60 0 1 5 8
##
## $y
## [1] 21.0 21.0 18.1 14.3 24.4 22.8 19.2 16.4 17.3 15.2 10.4 32.4 30.4 21.5 15.5
## [16] 15.2 19.2 27.3 26.0 19.7 15.0
##
## $n_updates
## [1] 0
##
## $avg_coefs
## [,1]
## cyl -0.36321389
## disp -0.41963495
## hp -0.47508218
## drat 0.55766798
## wt -0.54482177
## qsec 0.09679212
## vs 0.32281808
## am 0.48049052
## gear 0.25888698
## carb -0.47461167
## h1 -0.32480888
## h2 0.75205757
## h3 -0.35831420
## h4 0.04812892
## h5 0.72081638
##
## attr(,"class")
## [1] "rvfl"
# 3 - Predict on validation set -----
# The result is a list holding `$mean`, `$sd` and a `$simulate(n)` sampler
# (see the printed output below).
preds_validation <- bayesianrvfl::predict_rvfl(fit_obj, newx = X_validation)
print(preds_validation)
## $mean
## Cadillac Fleetwood Lotus Europa Ford Pantera L Chrysler Imperial
## 12.60216 26.90076 19.19089 12.77321
## Volvo 142E
## 25.33738
##
## $sd
## Cadillac Fleetwood Lotus Europa Ford Pantera L Chrysler Imperial
## 4.224966 4.228792 4.532150 4.180162
## Volvo 142E
## 4.041346
##
## $simulate
## function (n)
## MASS::mvrnorm(n = n, mu = res, Sigma = Sigma_newx)
## <bytecode: 0x7fb55fca7640>
## <environment: 0x7fb55fca6ae0>
# Nominal level (in %) of the Gaussian prediction intervals.
level <- 95
# Two-sided standard-normal quantile for that level (about 1.96 for 95).
multiplier <- qnorm(1 - (100 - level) / 200)
# Distribution of the validation errors (prediction minus observation).
summary(preds_validation$mean - y_validation)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -3.4992 -1.9268 2.2022 0.8209 3.3909 3.9374
preds_lower <- preds_validation$mean - multiplier * preds_validation$sd
preds_upper <- preds_validation$mean + multiplier * preds_validation$sd
# coverage rate: share of validation responses inside [lower, upper]
mean(preds_lower <= y_validation & y_validation <= preds_upper)
## [1] 1
# 4 - Update -----
# Online learning: fold the first test observation into the fitted model.
# NOTE(review): this section is "With matrix inversion", yet this first update
# passes method = "polyak", while the later updates (fit_obj3 ... fit_obj7)
# rely on update_params' default method -- confirm this is intended.
fit_obj2 <- bayesianrvfl::update_params(
  fit_obj,
  newx = X[idx_test[1], ],
  newy = y[idx_test[1]],
  method = "polyak"
)
print(fit_obj2)
## $coef
## [,1]
## cyl 0.156616913
## disp -0.002381108
## hp -0.183146989
## drat 0.347876182
## wt -0.120830328
## qsec -0.056125629
## vs -0.202146379
## am -0.044473938
## gear 0.066277686
## carb -0.017390995
## h1 0.059058223
## h2 0.316744350
## h3 -0.041448067
## h4 0.391569783
## h5 -0.055084236
##
## $Dn
## cyl disp hp drat wt
## cyl 0.0641456316 -0.0052760715 -0.0062410222 0.005822304 -0.0024686403
## disp -0.0052760715 0.0622811372 -0.0074713027 0.004016479 -0.0091846066
## hp -0.0062410222 -0.0074713027 0.0568135298 0.003192815 -0.0026470260
## drat 0.0058223040 0.0040164788 0.0031928154 0.055097079 0.0055691773
## wt -0.0024686403 -0.0091846066 -0.0026470260 0.005569177 0.0589471342
## qsec 0.0075639422 0.0025913924 0.0063259120 0.001488753 -0.0064484723
## vs 0.0079914361 0.0050643235 0.0028618316 -0.002787456 -0.0001204504
## am -0.0007384843 0.0028702003 -0.0004960659 -0.004987713 0.0066912080
## gear 0.0061324795 0.0041541267 -0.0044793662 -0.006225893 0.0002478994
## carb -0.0038594492 -0.0006591156 -0.0125736275 -0.004410981 -0.0047021678
## h1 -0.0078853168 -0.0071920042 -0.0117347358 0.002126297 -0.0056497049
## h2 0.0064432462 0.0008427470 0.0054354367 -0.012518603 0.0057446864
## h3 -0.0048297412 -0.0080197306 -0.0029111978 0.003526704 -0.0123647411
## h4 -0.0074352793 -0.0102001681 0.0032640334 0.004292136 -0.0019657223
## h5 0.0052613190 0.0025032477 0.0009585112 -0.002262603 0.0075715755
## qsec vs am gear carb
## cyl 0.0075639422 0.0079914361 -0.0007384843 0.0061324795 -0.0038594492
## disp 0.0025913924 0.0050643235 0.0028702003 0.0041541267 -0.0006591156
## hp 0.0063259120 0.0028618316 -0.0004960659 -0.0044793662 -0.0125736275
## drat 0.0014887532 -0.0027874565 -0.0049877134 -0.0062258929 -0.0044109811
## wt -0.0064484723 -0.0001204504 0.0066912080 0.0002478994 -0.0047021678
## qsec 0.0503112593 -0.0133213299 0.0114424676 0.0037203340 0.0079741656
## vs -0.0133213299 0.0509981024 0.0064243419 -0.0002187248 0.0012511732
## am 0.0114424676 0.0064243419 0.0532933538 -0.0095830826 -0.0064386072
## gear 0.0037203340 -0.0002187248 -0.0095830826 0.0511369322 -0.0127870791
## carb 0.0079741656 0.0012511732 -0.0064386072 -0.0127870791 0.0526739148
## h1 0.0041239477 -0.0003297649 0.0046260028 0.0046215565 -0.0032832747
## h2 -0.0007966727 0.0035853205 -0.0083514300 -0.0056704719 0.0087822325
## h3 -0.0013043556 0.0049873217 0.0012642085 0.0067785183 0.0010486529
## h4 0.0044846921 0.0012511349 0.0071472958 0.0051691971 0.0091636185
## h5 -0.0039500626 -0.0125143313 -0.0078167247 0.0043730045 0.0088466249
## h1 h2 h3 h4 h5
## cyl -7.885317e-03 6.443246e-03 -0.004829741 -7.435279e-03 0.0052613190
## disp -7.192004e-03 8.427470e-04 -0.008019731 -1.020017e-02 0.0025032477
## hp -1.173474e-02 5.435437e-03 -0.002911198 3.264033e-03 0.0009585112
## drat 2.126297e-03 -1.251860e-02 0.003526704 4.292136e-03 -0.0022626032
## wt -5.649705e-03 5.744686e-03 -0.012364741 -1.965722e-03 0.0075715755
## qsec 4.123948e-03 -7.966727e-04 -0.001304356 4.484692e-03 -0.0039500626
## vs -3.297649e-04 3.585321e-03 0.004987322 1.251135e-03 -0.0125143313
## am 4.626003e-03 -8.351430e-03 0.001264209 7.147296e-03 -0.0078167247
## gear 4.621557e-03 -5.670472e-03 0.006778518 5.169197e-03 0.0043730045
## carb -3.283275e-03 8.782232e-03 0.001048653 9.163618e-03 0.0088466249
## h1 5.830277e-02 8.513149e-05 -0.012269418 -9.946594e-05 -0.0014691719
## h2 8.513149e-05 5.490943e-02 -0.002961700 2.567923e-03 -0.0080429381
## h3 -1.226942e-02 -2.961700e-03 0.053107715 3.521041e-03 -0.0026247200
## h4 -9.946594e-05 2.567923e-03 0.003521041 4.677769e-02 0.0036453588
## h5 -1.469172e-03 -8.042938e-03 -0.002624720 3.645359e-03 0.0532094367
##
## $Sigma
## cyl disp hp drat wt
## x1 0.78402288 -0.064247705 -0.028087737 -0.0119358174 -0.015798304
## x2 -0.10382172 0.807518533 -0.054340704 -0.0180650516 -0.105712440
## x3 -0.10556876 -0.094307953 0.763268974 -0.0076972266 -0.025154143
## x4 0.09313687 0.050302732 0.020041669 0.7467713366 0.065436987
## x5 -0.06813847 -0.115446553 0.008647921 0.0008594692 0.774451315
## x6 0.11076780 0.032323223 0.066253115 0.0448599798 -0.088017041
## x7 0.14810215 0.061474405 -0.016074837 0.0519928805 -0.017799942
## x8 0.03535089 0.033136214 -0.059443751 0.0235754698 0.070176010
## x9 0.09567370 0.052209281 -0.077312430 -0.0481254296 -0.002758272
## x10 -0.08894283 -0.005086598 -0.116201940 -0.1336091004 -0.046582800
## h1 -0.13466658 -0.090011778 -0.112778100 -0.0368814186 -0.061090533
## h2 0.12044059 0.007622454 0.026221957 -0.0887170054 0.060725341
## h3 -0.08947317 -0.101204351 -0.005586788 -0.0075638042 -0.149891808
## h4 -0.12539734 -0.129166644 0.076854217 -0.0021323131 -0.014761000
## h5 0.13429844 0.026516436 -0.066009081 0.1008334081 0.073781537
## qsec vs am gear carb h1
## x1 0.044498822 0.188085424 0.12275509 0.09062335 -0.15836877 -0.109763941
## x2 -0.009227633 0.133532713 0.14325799 0.06281859 -0.09562047 -0.099246402
## x3 0.051829150 0.084625833 0.06788835 -0.05144018 -0.22334026 -0.156008436
## x4 0.040695563 -0.070253803 -0.11780921 -0.08501908 -0.01317299 0.030658981
## x5 -0.126671420 0.067668820 0.19432287 0.01251574 -0.14924507 -0.079429508
## x6 0.665442669 -0.197018262 0.10886875 0.04469075 0.13491406 0.055592986
## x7 -0.118333027 0.572955599 -0.05062589 -0.01435705 0.12575331 0.003740308
## x8 0.201503716 -0.002736656 0.55471067 -0.13530238 0.02643598 0.067746509
## x9 0.067759286 -0.034271994 -0.17278787 0.65622775 -0.12494155 0.062624670
## x10 0.056203830 0.090809312 0.03320137 -0.15510752 0.58485846 -0.049372232
## h1 0.013982478 0.058414375 0.15743838 0.06812226 -0.12254284 0.747159874
## h2 0.034255308 -0.024766793 -0.21864677 -0.08279966 0.20430471 0.007732799
## h3 -0.049270680 0.116148033 0.09696790 0.09450866 -0.05260644 -0.163294023
## h4 0.022778408 0.072232144 0.17971396 0.07430725 0.04665452 -0.006517989
## h5 0.028379340 -0.288309027 -0.29841769 0.03943500 0.27623883 -0.007151946
## h2 h3 h4 h5
## x1 0.079803966 -0.06285836 -0.11671117 0.203615337
## x2 0.008144366 -0.10396397 -0.14834018 0.141223465
## x3 0.068284209 -0.03786905 0.03054232 0.088567404
## x4 -0.160306263 0.04574279 0.06378138 -0.083973038
## x5 0.071411102 -0.16008816 -0.04225627 0.208441779
## x6 -0.009285204 -0.01670524 0.06400567 -0.090924715
## x7 0.049753667 0.06489833 0.03704411 -0.298631347
## x8 -0.104415394 0.01681248 0.11319596 -0.237959424
## x9 -0.071972084 0.08772573 0.07442548 0.006213334
## x10 0.110424300 0.01312181 0.10016271 0.233581671
## h1 -0.001421362 -0.15881998 -0.01655635 0.081204591
## h2 0.712041347 -0.03784999 0.05048442 -0.217484361
## h3 -0.040332704 0.68562002 0.03286985 0.048794544
## h4 0.030910608 0.04515896 0.59049371 0.136711061
## h5 -0.098783157 -0.03318333 0.07794990 0.484735720
##
## $scales
## cyl disp hp drat wt qsec
## 1.7557462 109.0477139 66.9717006 0.5484217 0.8366388 1.8010560
## vs am gear carb h1 h2
## 0.4916661 0.4916661 0.6997638 1.7224814 1.0168174 1.2212131
## h3 h4 h5
## 1.0386470 0.5080586 1.5540760
##
## $lambda
## [1] 12.9155
##
## $ym
## [1] 20.23182
##
## $xm
## cyl disp hp drat wt qsec
## 6.0909091 213.5136364 138.8636364 3.5786364 3.1420000 17.9586364
## vs am gear carb h1 h2
## 0.4090909 0.4090909 3.6818182 2.8181818 0.9168722 0.9168722
## h3 h4 h5
## 0.7979139 0.4134997 1.1133050
##
## $n_clusters
## [1] 0
##
## $clusters_scales
## $clusters_scales$means
## cyl disp hp drat wt qsec
## 6.1904762 218.5380952 141.0476190 3.5657143 3.1811429 17.9276190
## vs am gear carb
## 0.3809524 0.3809524 3.6666667 2.9047619
##
## $clusters_scales$sds
## cyl disp hp drat wt qsec
## 1.7353207 109.0975120 67.7779822 0.5580457 0.8364141 1.8376898
## vs am gear carb
## 0.4856209 0.4856209 0.7126966 1.7156080
##
##
## $clust_obj
## NULL
##
## $nb_hidden
## [1] 5
##
## $nodes_sim
## [1] "sobol"
##
## $activ
## [1] "relu"
##
## $nn_xm
## cyl disp hp drat wt qsec
## 6.0909091 213.5136364 138.8636364 3.5786364 3.1420000 17.9586364
## vs am gear carb
## 0.4090909 0.4090909 3.6818182 2.8181818
##
## $nn_scales
## cyl disp hp drat wt qsec
## 1.7557462 109.0477139 66.9717006 0.5484217 0.8366388 1.8010560
## vs am gear carb
## 0.4916661 0.4916661 0.6997638 1.7224814
##
## $fitted_values
## Mazda RX4 Mazda RX4 Wag Valiant Duster 360
## 20.49798 20.44760 19.29045 20.05476
## Merc 240D Merc 230 Merc 280 Merc 450SE
## 20.07872 19.91361 20.30749 19.87505
## Merc 450SL Merc 450SLC Lincoln Continental Fiat 128
## 19.96864 19.88149 19.65522 20.42344
## Honda Civic Toyota Corona Dodge Challenger AMC Javelin
## 21.28915 19.54064 20.53654 20.80859
## Pontiac Firebird Fiat X1-9 Porsche 914-2 Ferrari Dino
## 20.84109 20.47431 21.54739 19.96616
## Maserati Bora newx
## 19.62646 20.07523
##
## $compute_Sigma
## [1] TRUE
##
## $x
## cyl disp hp drat wt qsec vs am gear carb
## Mazda RX4 6 160.0 110 3.90 2.620 16.46 0 1 4 4
## Mazda RX4 Wag 6 160.0 110 3.90 2.875 17.02 0 1 4 4
## Valiant 6 225.0 105 2.76 3.460 20.22 1 0 3 1
## Duster 360 8 360.0 245 3.21 3.570 15.84 0 0 3 4
## Merc 240D 4 146.7 62 3.69 3.190 20.00 1 0 4 2
## Merc 230 4 140.8 95 3.92 3.150 22.90 1 0 4 2
## Merc 280 6 167.6 123 3.92 3.440 18.30 1 0 4 4
## Merc 450SE 8 275.8 180 3.07 4.070 17.40 0 0 3 3
## Merc 450SL 8 275.8 180 3.07 3.730 17.60 0 0 3 3
## Merc 450SLC 8 275.8 180 3.07 3.780 18.00 0 0 3 3
## Lincoln Continental 8 460.0 215 3.00 5.424 17.82 0 0 3 4
## Fiat 128 4 78.7 66 4.08 2.200 19.47 1 1 4 1
## Honda Civic 4 75.7 52 4.93 1.615 18.52 1 1 4 2
## Toyota Corona 4 120.1 97 3.70 2.465 20.01 1 0 3 1
## Dodge Challenger 8 318.0 150 2.76 3.520 16.87 0 0 3 2
## AMC Javelin 8 304.0 150 3.15 3.435 17.30 0 0 3 2
## Pontiac Firebird 8 400.0 175 3.08 3.845 17.05 0 0 3 2
## Fiat X1-9 4 79.0 66 4.08 1.935 18.90 1 1 4 1
## Porsche 914-2 4 120.3 91 4.43 2.140 16.70 0 1 5 2
## Ferrari Dino 6 145.0 175 3.62 2.770 15.50 0 1 5 6
## Maserati Bora 8 301.0 335 3.54 3.570 14.60 0 1 5 8
## newx 4 108.0 93 3.85 2.320 18.61 1 1 4 1
##
## $y
## [1] 21.0 21.0 18.1 14.3 24.4 22.8 19.2 16.4 17.3 15.2 10.4 32.4 30.4 21.5 15.5
## [16] 15.2 19.2 27.3 26.0 19.7 15.0 22.8
##
## $n_updates
## [1] 1
##
## $avg_coefs
## [,1]
## cyl 0.156616913
## disp -0.002381108
## hp -0.183146989
## drat 0.347876182
## wt -0.120830328
## qsec -0.056125629
## vs -0.202146379
## am -0.044473938
## gear 0.066277686
## carb -0.017390995
## h1 0.059058223
## h2 0.316744350
## h3 -0.041448067
## h4 0.391569783
## h5 -0.055084236
##
## attr(,"class")
## [1] "rvfl"
# 4 (continued) - Remaining point updates -----
# Feed the other test points one at a time: fit_obj3 has absorbed test
# points 1-2, ..., fit_obj7 all six. These calls do not pass `method`.
fit_obj3 <- bayesianrvfl::update_params(
  fit_obj2,
  newx = X[idx_test[2], ],
  newy = y[idx_test[2]]
)
fit_obj4 <- bayesianrvfl::update_params(
  fit_obj3,
  newx = X[idx_test[3], ],
  newy = y[idx_test[3]]
)
fit_obj5 <- bayesianrvfl::update_params(
  fit_obj4,
  newx = X[idx_test[4], ],
  newy = y[idx_test[4]]
)
fit_obj6 <- bayesianrvfl::update_params(
  fit_obj5,
  newx = X[idx_test[5], ],
  newy = y[idx_test[5]]
)
fit_obj7 <- bayesianrvfl::update_params(
  fit_obj6,
  newx = X[idx_test[6], ],
  newy = y[idx_test[6]]
)
# Coefficients side by side: column 1 is the initial fit, column k the model
# after k - 1 point updates.
mat_coefs <- do.call(
  cbind,
  lapply(list(fit_obj, fit_obj2, fit_obj3, fit_obj4,
              fit_obj5, fit_obj6, fit_obj7),
         function(obj) obj$coef)
)
print(mat_coefs)
## [,1] [,2] [,3] [,4] [,5] [,6]
## cyl -0.36321389 0.156616913 0.15526284 0.129827202 0.108708432 0.25553641
## disp -0.41963495 -0.002381108 0.02392163 -0.028084515 0.010614134 -0.05287947
## hp -0.47508218 -0.183146989 -0.18017625 -0.209942490 -0.185639455 -0.14172577
## drat 0.55766798 0.347876182 0.31677310 0.296955079 0.218041531 0.23138318
## wt -0.54482177 -0.120830328 -0.12408786 -0.094487284 -0.130846130 -0.33058489
## qsec 0.09679212 -0.056125629 -0.04983535 -0.038472349 -0.052940428 0.18643092
## vs 0.32281808 -0.202146379 -0.15477060 -0.121186864 -0.215509601 -0.20256839
## am 0.48049052 -0.044473938 -0.03601185 -0.027533562 0.027790516 0.23768575
## gear 0.25888698 0.066277686 0.05286751 0.066572603 0.032719777 0.03144328
## carb -0.47461167 -0.017390995 -0.03329515 -0.000839621 -0.059773632 -0.07664043
## h1 -0.32480888 0.059058223 0.04693134 0.083068120 0.070315911 0.21418578
## h2 0.75205757 0.316744350 0.29130604 0.288692930 0.339631245 0.60881874
## h3 -0.35831420 -0.041448067 -0.09083627 -0.040985378 -0.003155133 0.08415602
## h4 0.04812892 0.391569783 0.43137202 0.325474945 0.293790804 0.38395787
## h5 0.72081638 -0.055084236 -0.05131442 -0.050581055 0.009582396 0.47876079
## [,7]
## cyl 0.26362315
## disp -0.13067584
## hp -0.32368770
## drat -0.06532282
## wt -0.39288401
## qsec 0.37268019
## vs -0.19005954
## am 0.37007213
## gear 0.24037714
## carb -0.05610335
## h1 0.01489951
## h2 0.59106285
## h3 0.31512079
## h4 0.44075787
## h5 0.42961755
# Validation predictions after one update (fit_obj2) ...
preds_validation2 <- bayesianrvfl::predict_rvfl(fit_obj2, newx = X_validation)
# ... and after all six point updates (fit_obj7).
preds_validation3 <- bayesianrvfl::predict_rvfl(fit_obj7, newx = X_validation)
# 5 - Plots -----
# par(mfrow = c(3, 2))  # uncomment to lay the panels out on one page
# GCV error along the lambda grid returned by fit_rvfl; best_lambda is the
# grid point with the smallest GCV.
plot(
  x = log(obj_GCV$lambda),
  y = obj_GCV$GCV,
  type = 'l',
  main = 'Generalized Cross-validation error',
  xlab = "log(lambda)",
  ylab = "GCV"
)

# Validation targets (red), predictive means (blue) and the credible-interval
# bounds (gray); the y-range spans the targets, means and both bounds.
plot(
  x = y_validation,
  type = 'l', col = "red", lwd = 2,
  ylim = c(min(y_validation, preds_lower, preds_validation$mean),
           max(y_validation, preds_upper, preds_validation$mean)),
  main = 'Out-of-sample credible intervals',
  xlab = "obs#", ylab = "prediction"
)
lines(preds_validation$mean, col = "blue", lwd = 2)
lines(preds_upper, col = "gray")
lines(preds_lower, col = "gray")

# Observed vs predicted on the validation set before any update; the green
# identity line marks perfect predictions. The y-range matches the
# credible-interval panel so both share a scale.
plot(
  x = y_validation, y = preds_validation$mean,
  ylim = c(min(y_validation, preds_lower, preds_validation$mean),
           max(y_validation, preds_upper, preds_validation$mean)),
  main = 'observed vs predicted \n (before updates)',
  xlab = "observed", ylab = "predicted"
)
abline(a = 0, b = 1, col = "green", lwd = 2)

# One line per model coefficient, followed across the columns of mat_coefs.
# Column 1 holds the initial fit_obj (no update yet), so the x-axis is the
# number of point-updates absorbed: 0, 1, ..., ncol(mat_coefs) - 1. Without an
# explicit x, matplot would place the initial fit at "update# 1".
matplot(x = seq_len(ncol(mat_coefs)) - 1L, y = t(mat_coefs),
        type = 'l', lwd = 2,
        main = 'model coefficients \n after each update',
        xlab = "update#", ylab = "model coefficient")

# Observed vs predicted on the validation set after all six point-updates.
# The predictions must come from the final model fit_obj7 (preds_validation3),
# as the title states; preds_validation2 comes from fit_obj2, an early stage
# of the update chain. The y-range keeps the scale of the "before updates"
# panel but is widened to include the plotted predictions, so none of them is
# clipped out of the frame.
plot(
  x = y_validation,
  y = preds_validation3$mean,
  ylim = range(y_validation, preds_lower, preds_upper,
               preds_validation$mean, preds_validation3$mean),
  main = 'observed vs predicted \n (after 6 point-updates)',
  xlab = "observed", ylab = "predicted"
)
abline(a = 0, b = 1, col = "green", lwd = 2)

# 250 simulated draws on the validation set from the final model
# (preds_validation3, built from fit_obj7).
# NOTE(review): the x-axis is labelled "obs#"; that holds only if simulate()
# returns one row per draw, so that t() puts observations on the rows that
# matplot uses as x. Confirm against bayesianrvfl::predict_rvfl.
matplot(
  t(preds_validation3$simulate(250)),
  type = 'l', lwd = 2,
  main = 'predictive posterior simulation \n (after 6 point-updates)',
  xlab = "obs#", ylab = "prediction"
)

2 - With Polyak averaging
library("bayesianrvfl")
# Refit on the training set with the GCV-selected lambda (same call as in the
# matrix-inversion section). compute_Sigma = TRUE makes the returned object
# carry the $Sigma matrix. The outer parentheses print the fitted object.
(fit_obj <- bayesianrvfl::fit_rvfl(
  x = X_train,
  y = y_train,
  method = "solve",
  lambda = best_lambda,
  compute_Sigma = TRUE
))
## $coef
## [,1]
## cyl -0.36321389
## disp -0.41963495
## hp -0.47508218
## drat 0.55766798
## wt -0.54482177
## qsec 0.09679212
## vs 0.32281808
## am 0.48049052
## gear 0.25888698
## carb -0.47461167
## h1 -0.32480888
## h2 0.75205757
## h3 -0.35831420
## h4 0.04812892
## h5 0.72081638
##
## $Dn
## cyl disp hp drat wt
## cyl 0.0641456316 -0.0052760715 -0.0062410222 0.005822304 -0.0024686403
## disp -0.0052760715 0.0622811372 -0.0074713027 0.004016479 -0.0091846066
## hp -0.0062410222 -0.0074713027 0.0568135298 0.003192815 -0.0026470260
## drat 0.0058223040 0.0040164788 0.0031928154 0.055097079 0.0055691773
## wt -0.0024686403 -0.0091846066 -0.0026470260 0.005569177 0.0589471342
## qsec 0.0075639422 0.0025913924 0.0063259120 0.001488753 -0.0064484723
## vs 0.0079914361 0.0050643235 0.0028618316 -0.002787456 -0.0001204504
## am -0.0007384843 0.0028702003 -0.0004960659 -0.004987713 0.0066912080
## gear 0.0061324795 0.0041541267 -0.0044793662 -0.006225893 0.0002478994
## carb -0.0038594492 -0.0006591156 -0.0125736275 -0.004410981 -0.0047021678
## h1 -0.0078853168 -0.0071920042 -0.0117347358 0.002126297 -0.0056497049
## h2 0.0064432462 0.0008427470 0.0054354367 -0.012518603 0.0057446864
## h3 -0.0048297412 -0.0080197306 -0.0029111978 0.003526704 -0.0123647411
## h4 -0.0074352793 -0.0102001681 0.0032640334 0.004292136 -0.0019657223
## h5 0.0052613190 0.0025032477 0.0009585112 -0.002262603 0.0075715755
## qsec vs am gear carb
## cyl 0.0075639422 0.0079914361 -0.0007384843 0.0061324795 -0.0038594492
## disp 0.0025913924 0.0050643235 0.0028702003 0.0041541267 -0.0006591156
## hp 0.0063259120 0.0028618316 -0.0004960659 -0.0044793662 -0.0125736275
## drat 0.0014887532 -0.0027874565 -0.0049877134 -0.0062258929 -0.0044109811
## wt -0.0064484723 -0.0001204504 0.0066912080 0.0002478994 -0.0047021678
## qsec 0.0503112593 -0.0133213299 0.0114424676 0.0037203340 0.0079741656
## vs -0.0133213299 0.0509981024 0.0064243419 -0.0002187248 0.0012511732
## am 0.0114424676 0.0064243419 0.0532933538 -0.0095830826 -0.0064386072
## gear 0.0037203340 -0.0002187248 -0.0095830826 0.0511369322 -0.0127870791
## carb 0.0079741656 0.0012511732 -0.0064386072 -0.0127870791 0.0526739148
## h1 0.0041239477 -0.0003297649 0.0046260028 0.0046215565 -0.0032832747
## h2 -0.0007966727 0.0035853205 -0.0083514300 -0.0056704719 0.0087822325
## h3 -0.0013043556 0.0049873217 0.0012642085 0.0067785183 0.0010486529
## h4 0.0044846921 0.0012511349 0.0071472958 0.0051691971 0.0091636185
## h5 -0.0039500626 -0.0125143313 -0.0078167247 0.0043730045 0.0088466249
## h1 h2 h3 h4 h5
## cyl -7.885317e-03 6.443246e-03 -0.004829741 -7.435279e-03 0.0052613190
## disp -7.192004e-03 8.427470e-04 -0.008019731 -1.020017e-02 0.0025032477
## hp -1.173474e-02 5.435437e-03 -0.002911198 3.264033e-03 0.0009585112
## drat 2.126297e-03 -1.251860e-02 0.003526704 4.292136e-03 -0.0022626032
## wt -5.649705e-03 5.744686e-03 -0.012364741 -1.965722e-03 0.0075715755
## qsec 4.123948e-03 -7.966727e-04 -0.001304356 4.484692e-03 -0.0039500626
## vs -3.297649e-04 3.585321e-03 0.004987322 1.251135e-03 -0.0125143313
## am 4.626003e-03 -8.351430e-03 0.001264209 7.147296e-03 -0.0078167247
## gear 4.621557e-03 -5.670472e-03 0.006778518 5.169197e-03 0.0043730045
## carb -3.283275e-03 8.782232e-03 0.001048653 9.163618e-03 0.0088466249
## h1 5.830277e-02 8.513149e-05 -0.012269418 -9.946594e-05 -0.0014691719
## h2 8.513149e-05 5.490943e-02 -0.002961700 2.567923e-03 -0.0080429381
## h3 -1.226942e-02 -2.961700e-03 0.053107715 3.521041e-03 -0.0026247200
## h4 -9.946594e-05 2.567923e-03 0.003521041 4.677769e-02 0.0036453588
## h5 -1.469172e-03 -8.042938e-03 -0.002624720 3.645359e-03 0.0532094367
##
## $Sigma
## cyl disp hp drat wt
## cyl 0.828472689 -0.068143084 -0.080605901 0.07519795 -0.031883715
## disp -0.068143084 0.804391819 -0.096495585 0.05187482 -0.118623756
## hp -0.080605901 -0.096495585 0.733774954 0.04123680 -0.034187655
## drat 0.075197948 0.051874818 0.041236797 0.71160615 0.071928691
## wt -0.031883715 -0.118623756 -0.034187655 0.07192869 0.761331514
## qsec 0.097692070 0.033469120 0.081702295 0.01922799 -0.083285222
## vs 0.103213366 0.065408254 0.036961976 -0.03600138 -0.001555677
## am -0.009537891 0.037070063 -0.006406938 -0.06441880 0.086420275
## gear 0.079204018 0.053652609 -0.057853240 -0.08041050 0.003201744
## carb -0.049846703 -0.008512805 -0.162394644 -0.05697001 -0.060730832
## h1 -0.101842782 -0.092888307 -0.151559941 0.02746218 -0.072968745
## h2 0.083217724 0.010884496 0.070201364 -0.16168398 0.074195478
## h3 -0.062378506 -0.103578804 -0.037599565 0.04554913 -0.159696772
## h4 -0.096030324 -0.131740236 0.042156613 0.05543506 -0.025388280
## h5 0.067952548 0.032330688 0.012379648 -0.02922264 0.097790658
## qsec vs am gear carb
## cyl 0.09769207 0.103213366 -0.009537891 0.079204018 -0.049846703
## disp 0.03346912 0.065408254 0.037070063 0.053652609 -0.008512805
## hp 0.08170229 0.036961976 -0.006406938 -0.057853240 -0.162394644
## drat 0.01922799 -0.036001385 -0.064418796 -0.080410499 -0.056970012
## wt -0.08328522 -0.001555677 0.086420275 0.003201744 -0.060730832
## qsec 0.64979490 -0.172051591 0.147785152 0.048049961 0.102990309
## vs -0.17205159 0.658665821 0.082973567 -0.002824940 0.016159523
## am 0.14778515 0.082973567 0.688310132 -0.123770271 -0.083157809
## gear 0.04804996 -0.002824940 -0.123770271 0.660458876 -0.165151477
## carb 0.10299031 0.016159523 -0.083157809 -0.165151477 0.680309770
## h1 0.05326283 -0.004259078 0.059747124 0.059689698 -0.042405124
## h2 -0.01028942 0.046306195 -0.107862866 -0.073236961 0.113426894
## h3 -0.01684640 0.064413736 0.016327881 0.087547931 0.013543873
## h4 0.05792203 0.016159029 0.092310874 0.066762747 0.118352683
## h5 -0.05101702 -0.161628804 -0.100956881 0.056479525 0.114258554
## h1 h2 h3 h4 h5
## cyl -0.101842782 0.083217724 -0.06237851 -0.096030324 0.06795255
## disp -0.092888307 0.010884496 -0.10357880 -0.131740236 0.03233069
## hp -0.151559941 0.070201364 -0.03759957 0.042156613 0.01237965
## drat 0.027462183 -0.161683977 0.04554913 0.055435063 -0.02922264
## wt -0.072968745 0.074195478 -0.15969677 -0.025388280 0.09779066
## qsec 0.053262832 -0.010289424 -0.01684640 0.057922025 -0.05101702
## vs -0.004259078 0.046306195 0.06441374 0.016159029 -0.16162880
## am 0.059747124 -0.107862866 0.01632788 0.092310874 -0.10095688
## gear 0.059689698 -0.073236961 0.08754793 0.066762747 0.05647952
## carb -0.042405124 0.113426894 0.01354387 0.118352683 0.11425855
## h1 0.753009224 0.001099515 -0.15846563 -0.001284652 -0.01897509
## h2 0.001099515 0.709182620 -0.03825183 0.033166004 -0.10387854
## h3 -0.158465631 -0.038251826 0.68591252 0.045475988 -0.03389956
## h4 -0.001284652 0.033166004 0.04547599 0.604157095 0.04708162
## h5 -0.018975085 -0.103878541 -0.03389956 0.047081619 0.68722630
##
## $scales
## cyl disp hp drat wt qsec
## 1.7353207 109.0975120 67.7779822 0.5580457 0.8364141 1.8376898
## vs am gear carb h1 h2
## 0.4856209 0.4856209 0.7126966 1.7156080 0.9667548 1.2506310
## h3 h4 h5
## 1.0024913 0.5009253 1.5452324
##
## $lambda
## [1] 12.9155
##
## $ym
## [1] 20.10952
##
## $xm
## cyl disp hp drat wt qsec
## 6.1904762 218.5380952 141.0476190 3.5657143 3.1811429 17.9276190
## vs am gear carb h1 h2
## 0.3809524 0.3809524 3.6666667 2.9047619 0.9011436 0.9011436
## h3 h4 h5
## 0.7713531 0.4177551 1.0757344
##
## $n_clusters
## [1] 0
##
## $clusters_scales
## $clusters_scales$means
## cyl disp hp drat wt qsec
## 6.1904762 218.5380952 141.0476190 3.5657143 3.1811429 17.9276190
## vs am gear carb
## 0.3809524 0.3809524 3.6666667 2.9047619
##
## $clusters_scales$sds
## cyl disp hp drat wt qsec
## 1.7353207 109.0975120 67.7779822 0.5580457 0.8364141 1.8376898
## vs am gear carb
## 0.4856209 0.4856209 0.7126966 1.7156080
##
##
## $clust_obj
## NULL
##
## $nb_hidden
## [1] 5
##
## $nodes_sim
## [1] "sobol"
##
## $activ
## [1] "relu"
##
## $nn_xm
## cyl disp hp drat wt qsec
## 6.1904762 218.5380952 141.0476190 3.5657143 3.1811429 17.9276190
## vs am gear carb
## 0.3809524 0.3809524 3.6666667 2.9047619
##
## $nn_scales
## cyl disp hp drat wt qsec
## 1.7353207 109.0975120 67.7779822 0.5580457 0.8364141 1.8376898
## vs am gear carb
## 0.4856209 0.4856209 0.7126966 1.7156080
##
## $fitted_values
## Mazda RX4 Mazda RX4 Wag Valiant Duster 360
## 21.79908 21.59131 19.67633 14.78088
## Merc 240D Merc 230 Merc 280 Merc 450SE
## 22.79888 23.52919 19.78664 15.28762
## Merc 450SL Merc 450SLC Lincoln Continental Fiat 128
## 15.69314 15.62712 12.41438 27.77353
## Honda Civic Toyota Corona Dodge Challenger AMC Javelin
## 29.01498 23.50423 16.25692 17.07550
## Pontiac Firebird Fiat X1-9 Porsche 914-2 Ferrari Dino
## 16.08283 27.99151 26.24130 19.96917
## Maserati Bora
## 15.40547
##
## $compute_Sigma
## [1] TRUE
##
## $x
## cyl disp hp drat wt qsec vs am gear carb
## Mazda RX4 6 160.0 110 3.90 2.620 16.46 0 1 4 4
## Mazda RX4 Wag 6 160.0 110 3.90 2.875 17.02 0 1 4 4
## Valiant 6 225.0 105 2.76 3.460 20.22 1 0 3 1
## Duster 360 8 360.0 245 3.21 3.570 15.84 0 0 3 4
## Merc 240D 4 146.7 62 3.69 3.190 20.00 1 0 4 2
## Merc 230 4 140.8 95 3.92 3.150 22.90 1 0 4 2
## Merc 280 6 167.6 123 3.92 3.440 18.30 1 0 4 4
## Merc 450SE 8 275.8 180 3.07 4.070 17.40 0 0 3 3
## Merc 450SL 8 275.8 180 3.07 3.730 17.60 0 0 3 3
## Merc 450SLC 8 275.8 180 3.07 3.780 18.00 0 0 3 3
## Lincoln Continental 8 460.0 215 3.00 5.424 17.82 0 0 3 4
## Fiat 128 4 78.7 66 4.08 2.200 19.47 1 1 4 1
## Honda Civic 4 75.7 52 4.93 1.615 18.52 1 1 4 2
## Toyota Corona 4 120.1 97 3.70 2.465 20.01 1 0 3 1
## Dodge Challenger 8 318.0 150 2.76 3.520 16.87 0 0 3 2
## AMC Javelin 8 304.0 150 3.15 3.435 17.30 0 0 3 2
## Pontiac Firebird 8 400.0 175 3.08 3.845 17.05 0 0 3 2
## Fiat X1-9 4 79.0 66 4.08 1.935 18.90 1 1 4 1
## Porsche 914-2 4 120.3 91 4.43 2.140 16.70 0 1 5 2
## Ferrari Dino 6 145.0 175 3.62 2.770 15.50 0 1 5 6
## Maserati Bora 8 301.0 335 3.54 3.570 14.60 0 1 5 8
##
## $y
## [1] 21.0 21.0 18.1 14.3 24.4 22.8 19.2 16.4 17.3 15.2 10.4 32.4 30.4 21.5 15.5
## [16] 15.2 19.2 27.3 26.0 19.7 15.0
##
## $n_updates
## [1] 0
##
## $avg_coefs
## [,1]
## cyl -0.36321389
## disp -0.41963495
## hp -0.47508218
## drat 0.55766798
## wt -0.54482177
## qsec 0.09679212
## vs 0.32281808
## am 0.48049052
## gear 0.25888698
## carb -0.47461167
## h1 -0.32480888
## h2 0.75205757
## h3 -0.35831420
## h4 0.04812892
## h5 0.72081638
##
## attr(,"class")
## [1] "rvfl"
# 3 - Predict on validation set -----
preds_validation <- bayesianrvfl::predict_rvfl(
  fit_obj,
  newx = X_validation
)
# Two-sided `level`% normal interval: (100 - level) / 2 percent in each tail,
# i.e. qnorm(0.975) for level = 95.
level <- 95
multiplier <- qnorm(1 - (100 - level)/200)
# Distribution of the prediction errors (predicted minus observed).
summary(preds_validation$mean - y_validation)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -3.4992 -1.9268 2.2022 0.8209 3.3909 3.9374
# Interval bounds: predictive mean +/- multiplier * predictive sd.
preds_upper <- preds_validation$mean + multiplier * preds_validation$sd
preds_lower <- preds_validation$mean - multiplier * preds_validation$sd
# Empirical coverage: share of validation targets lying in [lower, upper].
mean((preds_lower <= y_validation) * (y_validation <= preds_upper))
## [1] 1
# 4 - Update -----
# Add the test points idx_test[1], ..., idx_test[6] one at a time (online
# learning). Each call starts from the previous fit; method = "polyak" selects
# the Polyak-averaged update.
fit_obj2 <- bayesianrvfl::update_params(
  fit_obj,
  newx = X[idx_test[1], ],
  newy = y[idx_test[1]],
  method = "polyak"
)
fit_obj3 <- bayesianrvfl::update_params(
  fit_obj2,
  newx = X[idx_test[2], ],
  newy = y[idx_test[2]],
  method = "polyak"
)
fit_obj4 <- bayesianrvfl::update_params(
  fit_obj3,
  newx = X[idx_test[3], ],
  newy = y[idx_test[3]],
  method = "polyak"
)
fit_obj5 <- bayesianrvfl::update_params(
  fit_obj4,
  newx = X[idx_test[4], ],
  newy = y[idx_test[4]],
  method = "polyak"
)
fit_obj6 <- bayesianrvfl::update_params(
  fit_obj5,
  newx = X[idx_test[5], ],
  newy = y[idx_test[5]],
  method = "polyak"
)
fit_obj7 <- bayesianrvfl::update_params(
  fit_obj6,
  newx = X[idx_test[6], ],
  newy = y[idx_test[6]],
  method = "polyak"
)
# Coefficient trajectory: column 1 is the initial fit, column k + 1 the fit
# after k point-updates. The outer parentheses print the matrix.
(mat_coefs <- do.call(cbind, lapply(
  list(fit_obj, fit_obj2, fit_obj3, fit_obj4, fit_obj5, fit_obj6, fit_obj7),
  function(obj) obj$coef
)))
## [,1] [,2] [,3] [,4] [,5] [,6]
## cyl -0.36321389 0.156616913 0.15129821 -0.005704671 0.005046952 -0.4791379
## disp -0.41963495 -0.002381108 0.03952438 -0.150807825 -0.095685593 -0.6202010
## hp -0.47508218 -0.183146989 -0.22741808 -0.307549435 -0.279807464 -0.7177286
## drat 0.55766798 0.347876182 0.25447984 0.359539343 0.281360219 0.7520789
## wt -0.54482177 -0.120830328 -0.11186749 -0.162639679 -0.201756336 -0.8419310
## qsec 0.09679212 -0.056125629 0.02836236 0.107460728 0.049437217 0.4602919
## vs 0.32281808 -0.202146379 -0.07869073 0.044916974 -0.086675154 0.3405753
## am 0.48049052 -0.044473938 -0.12994323 -0.016944264 0.069203005 0.5741591
## gear 0.25888698 0.066277686 -0.03380930 0.097803653 0.037878894 0.2366640
## carb -0.47461167 -0.017390995 -0.12581946 -0.065429630 -0.150177460 -0.5469287
## h1 -0.32480888 0.059058223 0.02886339 -0.015250897 0.004325417 -0.3459104
## h2 0.75205757 0.316744350 0.23962242 0.342580340 0.422119790 1.1714129
## h3 -0.35831420 -0.041448067 -0.12036108 -0.121029370 -0.037882410 -0.3151740
## h4 0.04812892 0.391569783 0.52641175 0.216096557 0.196435684 -0.1269167
## h5 0.72081638 -0.055084236 -0.03157788 0.071551387 0.151658598 1.0546296
## [,7]
## cyl -0.4739488
## disp -0.6144207
## hp -0.7098485
## drat 0.7533529
## wt -0.8378457
## qsec 0.4531890
## vs 0.3363085
## am 0.5705156
## gear 0.2322069
## carb -0.5433451
## h1 -0.3379071
## h2 1.1681003
## h3 -0.3144344
## h4 -0.1247607
## h5 1.0513688
# Predictive summaries on the validation set after the first point-update
# (fit_obj2) and after the sixth (fit_obj7).
preds_validation2 <- bayesianrvfl::predict_rvfl(
  fit_obj2,
  newx = X_validation
)
preds_validation3 <- bayesianrvfl::predict_rvfl(
  fit_obj7,
  newx = X_validation
)
# 5 - Plots -----
# GCV error along the lambda grid returned by fit_rvfl; best_lambda is the
# grid point with the smallest GCV.
plot(
  x = log(obj_GCV$lambda),
  y = obj_GCV$GCV,
  type = 'l',
  main = 'Generalized Cross-validation error',
  xlab = "log(lambda)",
  ylab = "GCV"
)

# Validation targets (red), predictive means (blue) and the `level`% interval
# bounds (gray); the y-range spans the targets, means and both bounds.
plot(
  x = y_validation,
  type = 'l', col = "red", lwd = 2,
  ylim = c(min(y_validation, preds_lower, preds_validation$mean),
           max(y_validation, preds_upper, preds_validation$mean)),
  main = 'Out-of-sample credible intervals',
  xlab = "obs#", ylab = "prediction"
)
lines(preds_validation$mean, col = "blue", lwd = 2)
lines(preds_upper, col = "gray")
lines(preds_lower, col = "gray")

# Observed vs predicted on the validation set before any update; the green
# identity line marks perfect predictions. The y-range matches the
# credible-interval panel so both share a scale.
plot(
  x = y_validation, y = preds_validation$mean,
  ylim = c(min(y_validation, preds_lower, preds_validation$mean),
           max(y_validation, preds_upper, preds_validation$mean)),
  main = 'observed vs predicted \n (before updates)',
  xlab = "observed", ylab = "predicted"
)
abline(a = 0, b = 1, col = "green", lwd = 2)

# One line per model coefficient, followed across the columns of mat_coefs.
# Column 1 holds the initial fit_obj (no update yet), so the x-axis is the
# number of point-updates absorbed: 0, 1, ..., ncol(mat_coefs) - 1. Without an
# explicit x, matplot would place the initial fit at "update# 1".
matplot(x = seq_len(ncol(mat_coefs)) - 1L, y = t(mat_coefs),
        type = 'l', lwd = 2,
        main = 'model coefficients \n after each update',
        xlab = "update#", ylab = "model coefficient")

# Observed vs predicted on the validation set after all six Polyak
# point-updates. The predictions must come from the final model fit_obj7
# (preds_validation3), as the title states; preds_validation2 comes from
# fit_obj2, which has absorbed only idx_test[1]. The y-range keeps the scale of
# the "before updates" panel but is widened to include the plotted predictions,
# since the updated coefficients can move the predictions well outside it.
plot(
  x = y_validation,
  y = preds_validation3$mean,
  ylim = range(y_validation, preds_lower, preds_upper,
               preds_validation$mean, preds_validation3$mean),
  main = 'observed vs predicted \n (after 6 point-updates)',
  xlab = "observed", ylab = "predicted"
)
abline(a = 0, b = 1, col = "green", lwd = 2)

# 250 simulated draws on the validation set from the model after all six
# Polyak point-updates (preds_validation3, built from fit_obj7).
# NOTE(review): the x-axis is labelled "obs#"; that holds only if simulate()
# returns one row per draw, so that t() puts observations on the rows that
# matplot uses as x. Confirm against bayesianrvfl::predict_rvfl.
matplot(
  t(preds_validation3$simulate(250)),
  type = 'l', lwd = 2,
  main = 'predictive posterior simulation \n (after 6 point-updates)',
  xlab = "obs#", ylab = "prediction"
)
