Fits a `glmnet` penalized regression model with a consistent interface. Supports regression and binary classification.

wrap_glmnet(x, y, ...)

# S3 method for class 'wrap_glmnet'
predict(object, newx, type = c("class", "prob"), s = NULL, ...)

# S3 method for class 'wrap_glmnet'
print(x, ...)

Arguments

x

A matrix or data.frame of features. For the `print()` method, a fitted `wrap_glmnet` object.

y

A factor or character vector for classification, numeric for regression.

...

Additional arguments passed to [glmnet::glmnet()]. Pass `family = "binomial"` for binary classification.

object

A fitted `wrap_glmnet` object.

newx

A matrix or data.frame of new observations.

type

`"class"` (default) for class labels, `"prob"` for a probability matrix. Ignored for regression.

s

Lambda value at which to predict. Defaults to the midpoint of the fitted lambda path. Pass `s = cv_fit$lambda.min` when selecting lambda with [glmnet::cv.glmnet()].

Value

An object of class `wrap_glmnet` with fields:

fit

The fitted glmnet model.

levels

Class levels (`NULL` for regression).

task

"classification" or "regression".

Note

Multiclass (`family = "multinomial"`) is not yet supported. `predict()` uses the midpoint of the lambda path unless a specific `s` is supplied. For an optimal lambda, run [glmnet::cv.glmnet()] externally and pass `s = fit$lambda.min`, as sketched below.
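
A minimal sketch of external lambda selection (assumes `X` and `y` as built in the Examples below):

cv_fit <- glmnet::cv.glmnet(X, y, family = "binomial")
mod <- wrap_glmnet(X, y, family = "binomial")
predict(mod, newx = X, type = "prob", s = cv_fit$lambda.min)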

Examples

# \donttest{
X <- as.matrix(iris[iris$Species != "virginica", 1:4])
y <- droplevels(iris[iris$Species != "virginica", "Species"])
mod <- wrap_glmnet(X, y, family = "binomial")
predict(mod, newx = X, type = "class")
#>   [1] setosa     setosa     setosa     setosa     setosa     setosa    
#>   [7] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [13] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [19] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [25] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [31] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [37] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [43] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [49] setosa     setosa     versicolor versicolor versicolor versicolor
#>  [55] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [61] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [67] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [73] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [79] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [85] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [91] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [97] versicolor versicolor versicolor versicolor
#> Levels: setosa versicolor
predict(mod, newx = X, type = "prob")
#>          setosa  versicolor
#> 1   0.987162476 0.012837524
#> 2   0.978222799 0.021777201
#> 3   0.985149785 0.014850215
#> 4   0.976740818 0.023259182
#> 5   0.988456017 0.011543983
#> 6   0.976128195 0.023871805
#> 7   0.981411182 0.018588818
#> 8   0.983044703 0.016955297
#> 9   0.975810653 0.024189347
#> 10  0.982120067 0.017879933
#> 11  0.987661643 0.012338357
#> 12  0.979869911 0.020130089
#> 13  0.983265177 0.016734823
#> 14  0.990028450 0.009971550
#> 15  0.994673874 0.005326126
#> 16  0.990030689 0.009969311
#> 17  0.987995263 0.012004737
#> 18  0.983274420 0.016725580
#> 19  0.979605634 0.020394366
#> 20  0.985537552 0.014462448
#> 21  0.976115098 0.023884902
#> 22  0.979073715 0.020926285
#> 23  0.994230941 0.005769059
#> 24  0.942528486 0.057471514
#> 25  0.966450029 0.033549971
#> 26  0.969383865 0.030616135
#> 27  0.966045113 0.033954887
#> 28  0.984746751 0.015253249
#> 29  0.985726081 0.014273919
#> 30  0.975160545 0.024839455
#> 31  0.972418772 0.027581228
#> 32  0.971336487 0.028663513
#> 33  0.993825975 0.006174025
#> 34  0.993910528 0.006089472
#> 35  0.976740818 0.023259182
#> 36  0.987502490 0.012497510
#> 37  0.989199807 0.010800193
#> 38  0.991150391 0.008849609
#> 39  0.981652500 0.018347500
#> 40  0.983044703 0.016955297
#> 41  0.985919990 0.014080010
#> 42  0.950668719 0.049331281
#> 43  0.985149785 0.014850215
#> 44  0.948761636 0.051238364
#> 45  0.962803613 0.037196387
#> 46  0.971704822 0.028295178
#> 47  0.986813334 0.013186666
#> 48  0.982362086 0.017637914
#> 49  0.987661643 0.012338357
#> 50  0.984131552 0.015868448
#> 51  0.006874998 0.993125002
#> 52  0.007452406 0.992547594
#> 53  0.003339010 0.996660990
#> 54  0.011566372 0.988433628
#> 55  0.004083623 0.995916377
#> 56  0.008286700 0.991713300
#> 57  0.004485214 0.995514786
#> 58  0.090226617 0.909773383
#> 59  0.007751134 0.992248866
#> 60  0.016117641 0.983882359
#> 61  0.043493251 0.956506749
#> 62  0.010129093 0.989870907
#> 63  0.022977422 0.977022578
#> 64  0.004989001 0.995010999
#> 65  0.042965873 0.957034127
#> 66  0.010396118 0.989603882
#> 67  0.006019142 0.993980858
#> 68  0.032695383 0.967304617
#> 69  0.002555587 0.997444413
#> 70  0.028718141 0.971281859
#> 71  0.001981559 0.998018441
#> 72  0.019638465 0.980361535
#> 73  0.001754453 0.998245547
#> 74  0.007645149 0.992354851
#> 75  0.013028213 0.986971787
#> 76  0.009346250 0.990653750
#> 77  0.003766236 0.996233764
#> 78  0.001474221 0.998525779
#> 79  0.005408873 0.994591127
#> 80  0.079762661 0.920237339
#> 81  0.030657663 0.969342337
#> 82  0.046960398 0.953039602
#> 83  0.027263469 0.972736531
#> 84  0.001172949 0.998827051
#> 85  0.006019142 0.993980858
#> 86  0.007067079 0.992932921
#> 87  0.004730437 0.995269563
#> 88  0.005780313 0.994219687
#> 89  0.020426861 0.979573139
#> 90  0.014301526 0.985698474
#> 91  0.010390339 0.989609661
#> 92  0.006606123 0.993393877
#> 93  0.020694951 0.979305049
#> 94  0.081781170 0.918218830
#> 95  0.012521725 0.987478275
#> 96  0.022387225 0.977612775
#> 97  0.015479258 0.984520742
#> 98  0.013028213 0.986971787
#> 99  0.124852867 0.875147133
#> 100 0.016539899 0.983460101
# }
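
# A minimal regression sketch (output not shown): a numeric y selects the
# regression task via glmnet's default gaussian family, and `type` is
# ignored. `X2`/`y2` are illustrative names, not from the package.
X2 <- as.matrix(mtcars[, c("wt", "hp", "disp")])
y2 <- mtcars$mpg
mod2 <- wrap_glmnet(X2, y2)
predict(mod2, newx = X2)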