Fits an `xgboost` model behind a consistent `x`/`y` fitting and `predict()` interface. Supports binary classification, multiclass classification, and regression.

wrap_xgboost(x, y, ...)

# S3 method for class 'wrap_xgboost'
predict(object, newx, type = c("class", "prob"), ...)

# S3 method for class 'wrap_xgboost'
print(x, ...)

Arguments

x

A matrix or data.frame of features.

y

A factor or character vector for classification, numeric for regression.
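The response handling described for `y` can be sketched in base R. This is a minimal sketch under assumptions. `encode_response()` is a hypothetical helper and is not part of the package. The sketch also assumes that character vectors are converted to factors. The one firm constraint it encodes comes from xgboost: classification labels must be integers `0` to `K - 1`.

```r
# Hypothetical helper, not part of the package API: infer the task from `y`
# and encode class labels as the 0-based integers xgboost expects.
encode_response <- function(y) {
  if (is.character(y)) y <- factor(y)  # assumption: characters become classes
  if (is.factor(y)) {
    list(
      task   = "classification",
      label  = as.integer(y) - 1L,      # factor codes 1..K mapped to 0..K-1
      levels = levels(y)
    )
  } else if (is.numeric(y)) {
    list(task = "regression", label = as.numeric(y), levels = NULL)
  } else {
    stop("`y` must be a factor, character, or numeric vector.")
  }
}

enc <- encode_response(c("b", "a", "b", "c"))
enc$levels  # "a" "b" "c"
enc$label   # 1 0 1 2
```

Unused factor levels would still count as classes in this sketch. Dropping them first with `droplevels()`, as the example below does, avoids that.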

...

Additional arguments passed to `xgboost::xgboost()`. The `objective` argument is required for classification (e.g. `"binary:logistic"`, `"multi:softprob"`).
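The requirement on `objective` can be illustrated with a small validation step. This sketch is hypothetical: `check_objective()` is not a package function, and it does not show how `wrap_xgboost()` actually handles `...`. It relies on one xgboost fact: the multiclass objectives `"multi:softprob"` and `"multi:softmax"` need a `num_class` parameter. This page does not say whether the wrapper supplies `num_class` itself. Filling it in automatically, as the sketch does, is only one possible design.

```r
# Hypothetical validation of the arguments forwarded through `...`.
# `dots` stands for list(...) as a wrapper would capture it.
check_objective <- function(task, n_classes, dots) {
  obj <- dots$objective
  if (task == "classification" && is.null(obj)) {
    stop("`objective` is required for classification, e.g. \"binary:logistic\".")
  }
  # xgboost's multiclass objectives need the number of classes
  if (!is.null(obj) && startsWith(obj, "multi:") && is.null(dots$num_class)) {
    dots$num_class <- n_classes  # assumption: fill it in from the factor levels
  }
  dots
}

args <- check_objective("classification", 3L, list(objective = "multi:softprob"))
args$num_class  # 3
```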

object

A fitted `wrap_xgboost` object.

newx

A matrix or data.frame of new observations.

type

`"class"` (default) for class labels, `"prob"` for a probability matrix. Ignored for regression.
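The example output below suggests how `type` maps raw scores to results. Each probability row sums to one, the columns are named after the class levels, and the predicted class is the level with the larger probability. The sketch below reproduces that mapping. It is an assumption, not the package's code: `format_prediction()` is hypothetical. It also assumes the following layout for `raw`. Under `"binary:logistic"`, `raw` holds one probability per row for the second level. Under `"multi:softprob"`, it holds `K` probabilities per row, laid out row by row.

```r
# Hypothetical post-processing of raw xgboost predictions for classification.
format_prediction <- function(raw, levels, type = c("class", "prob")) {
  type <- match.arg(type)  # "class" is the default, as in predict()
  k <- length(levels)
  prob <- if (k == 2L) {
    cbind(1 - raw, raw)                 # P(first level), P(second level)
  } else {
    matrix(raw, ncol = k, byrow = TRUE) # one row per observation
  }
  colnames(prob) <- levels
  if (type == "prob") return(prob)
  # pick the most probable level for each row
  factor(levels[max.col(prob, ties.method = "first")], levels = levels)
}

format_prediction(c(0.02, 0.98), c("setosa", "versicolor"))
# a factor with values setosa, versicolor
```

Returning a factor that carries the full level set keeps the labels comparable with the training `y`, even when a class is never predicted.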

Value

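`predict()` returns a factor of class labels (`type = "class"`) or a numeric matrix of class probabilities with one column per level (`type = "prob"`). `wrap_xgboost()` returns an object of class `wrap_xgboost` with components: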

fit

The fitted xgboost model.

levels

Class levels (`NULL` for regression).

task

`"classification"` or `"regression"`.

objective

The xgboost objective string, stored at fit time.
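The components listed above could be bundled into an S3 object as in the sketch below. A matching `print()` method is included for illustration. Both functions are hypothetical and mirror only the documented fields; the package's real constructor and print output may differ. Defining `print.wrap_xgboost()` at top level masks the package's own method, so try this only in a session where the package is not attached.

```r
# Hypothetical constructor mirroring the documented components.
new_wrap_xgboost <- function(fit, levels, task, objective) {
  stopifnot(task %in% c("classification", "regression"))
  structure(
    list(fit = fit, levels = levels, task = task, objective = objective),
    class = "wrap_xgboost"
  )
}

# Hypothetical print method; S3 dispatch selects it via the class attribute.
print.wrap_xgboost <- function(x, ...) {
  cat("<wrap_xgboost>", x$task, "model, objective:", x$objective, "\n")
  if (!is.null(x$levels)) cat("Classes:", paste(x$levels, collapse = ", "), "\n")
  invisible(x)  # print methods conventionally return their input invisibly
}

m <- new_wrap_xgboost(fit = NULL, levels = c("setosa", "versicolor"),
                      task = "classification", objective = "binary:logistic")
print(m)
```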

Examples

# \donttest{
X <- as.matrix(iris[iris$Species != "virginica", 1:4])
y <- droplevels(iris[iris$Species != "virginica", "Species"])
mod <- wrap_xgboost(X, y, nrounds = 50, objective = "binary:logistic", verbose = 0)
predict(mod, newx = X, type = "class")
#>   [1] setosa     setosa     setosa     setosa     setosa     setosa    
#>   [7] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [13] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [19] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [25] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [31] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [37] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [43] setosa     setosa     setosa     setosa     setosa     setosa    
#>  [49] setosa     setosa     versicolor versicolor versicolor versicolor
#>  [55] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [61] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [67] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [73] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [79] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [85] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [91] versicolor versicolor versicolor versicolor versicolor versicolor
#>  [97] versicolor versicolor versicolor versicolor
#> Levels: setosa versicolor
predict(mod, newx = X, type = "prob")
#>            setosa versicolor
#>   [1,] 0.98144253 0.01855747
#>   [2,] 0.98144253 0.01855747
#>   [3,] 0.98144253 0.01855747
#>   [4,] 0.98144253 0.01855747
#>   [5,] 0.98144253 0.01855747
#>   [6,] 0.98144253 0.01855747
#>   [7,] 0.98144253 0.01855747
#>   [8,] 0.98144253 0.01855747
#>   [9,] 0.98144253 0.01855747
#>  [10,] 0.98144253 0.01855747
#>  [11,] 0.98144253 0.01855747
#>  [12,] 0.98144253 0.01855747
#>  [13,] 0.98144253 0.01855747
#>  [14,] 0.98144253 0.01855747
#>  [15,] 0.98144253 0.01855747
#>  [16,] 0.98144253 0.01855747
#>  [17,] 0.98144253 0.01855747
#>  [18,] 0.98144253 0.01855747
#>  [19,] 0.98144253 0.01855747
#>  [20,] 0.98144253 0.01855747
#>  [21,] 0.98144253 0.01855747
#>  [22,] 0.98144253 0.01855747
#>  [23,] 0.98144253 0.01855747
#>  [24,] 0.98144253 0.01855747
#>  [25,] 0.98144253 0.01855747
#>  [26,] 0.98144253 0.01855747
#>  [27,] 0.98144253 0.01855747
#>  [28,] 0.98144253 0.01855747
#>  [29,] 0.98144253 0.01855747
#>  [30,] 0.98144253 0.01855747
#>  [31,] 0.98144253 0.01855747
#>  [32,] 0.98144253 0.01855747
#>  [33,] 0.98144253 0.01855747
#>  [34,] 0.98144253 0.01855747
#>  [35,] 0.98144253 0.01855747
#>  [36,] 0.98144253 0.01855747
#>  [37,] 0.98144253 0.01855747
#>  [38,] 0.98144253 0.01855747
#>  [39,] 0.98144253 0.01855747
#>  [40,] 0.98144253 0.01855747
#>  [41,] 0.98144253 0.01855747
#>  [42,] 0.98144253 0.01855747
#>  [43,] 0.98144253 0.01855747
#>  [44,] 0.98144253 0.01855747
#>  [45,] 0.98144253 0.01855747
#>  [46,] 0.98144253 0.01855747
#>  [47,] 0.98144253 0.01855747
#>  [48,] 0.98144253 0.01855747
#>  [49,] 0.98144253 0.01855747
#>  [50,] 0.98144253 0.01855747
#>  [51,] 0.01855749 0.98144251
#>  [52,] 0.01855749 0.98144251
#>  [53,] 0.01855749 0.98144251
#>  [54,] 0.01855749 0.98144251
#>  [55,] 0.01855749 0.98144251
#>  [56,] 0.01855749 0.98144251
#>  [57,] 0.01855749 0.98144251
#>  [58,] 0.01855749 0.98144251
#>  [59,] 0.01855749 0.98144251
#>  [60,] 0.01855749 0.98144251
#>  [61,] 0.01855749 0.98144251
#>  [62,] 0.01855749 0.98144251
#>  [63,] 0.01855749 0.98144251
#>  [64,] 0.01855749 0.98144251
#>  [65,] 0.01855749 0.98144251
#>  [66,] 0.01855749 0.98144251
#>  [67,] 0.01855749 0.98144251
#>  [68,] 0.01855749 0.98144251
#>  [69,] 0.01855749 0.98144251
#>  [70,] 0.01855749 0.98144251
#>  [71,] 0.01855749 0.98144251
#>  [72,] 0.01855749 0.98144251
#>  [73,] 0.01855749 0.98144251
#>  [74,] 0.01855749 0.98144251
#>  [75,] 0.01855749 0.98144251
#>  [76,] 0.01855749 0.98144251
#>  [77,] 0.01855749 0.98144251
#>  [78,] 0.01855749 0.98144251
#>  [79,] 0.01855749 0.98144251
#>  [80,] 0.01855749 0.98144251
#>  [81,] 0.01855749 0.98144251
#>  [82,] 0.01855749 0.98144251
#>  [83,] 0.01855749 0.98144251
#>  [84,] 0.01855749 0.98144251
#>  [85,] 0.01855749 0.98144251
#>  [86,] 0.01855749 0.98144251
#>  [87,] 0.01855749 0.98144251
#>  [88,] 0.01855749 0.98144251
#>  [89,] 0.01855749 0.98144251
#>  [90,] 0.01855749 0.98144251
#>  [91,] 0.01855749 0.98144251
#>  [92,] 0.01855749 0.98144251
#>  [93,] 0.01855749 0.98144251
#>  [94,] 0.01855749 0.98144251
#>  [95,] 0.01855749 0.98144251
#>  [96,] 0.01855749 0.98144251
#>  [97,] 0.01855749 0.98144251
#>  [98,] 0.01855749 0.98144251
#>  [99,] 0.01855749 0.98144251
#> [100,] 0.01855749 0.98144251
# }