| Function | Works |
|---|---|
| tidypredict_fit(), tidypredict_sql(), parse_model() | ✔ |
| tidypredict_to_column() | ✔ |
| tidypredict_test() | ✔ |
| tidypredict_interval(), tidypredict_sql_interval() | ✗ |
| parsnip | ✔ |

tidypredict_ functions

library(xgboost)
# Logistic-loss objective written in the shape xgboost expects for a
# user-defined objective (see the `obj` argument of xgb.train()): it takes the
# raw margin predictions and the training DMatrix and returns the per-row
# gradient and hessian of the log loss.
# Note: the xgb.train() call below does not use this function; it relies on
# the built-in "binary:logistic" objective instead.
#
# @param preds Numeric vector of raw (untransformed) margin predictions.
# @param dtrain The xgb.DMatrix being trained on; its "label" field is read
#   (presumably 0/1 labels, as mtcars$am is here).
# @return A list with numeric vectors `grad` (p - y) and `hess` (p * (1 - p)),
#   where p is the logistic transform of `preds`.
logregobj <- function(preds, dtrain) {
  labels <- xgboost::getinfo(dtrain, "label")
  # Map margins to probabilities with the logistic function.
  prob <- 1 / (1 + exp(-preds))
  list(
    grad = prob - labels,
    hess = prob * (1 - prob)
  )
}
# Training data for the binary model: every mtcars column except `am` becomes
# a feature, and `am` (0/1) is the label. Dropping the label column by name
# rather than by position ties the feature matrix to the label column, so the
# two stay consistent even if the column order changes.
xgb_bin_data <- xgboost::xgb.DMatrix(
  as.matrix(mtcars[, names(mtcars) != "am"]),
  label = mtcars$am
)
# Train 50 rounds of depth-2 trees with xgboost's built-in logistic objective.
# base_score = 0.5 is the starting probability; its logit shows up as the
# trailing log(0.5/(1 - 0.5)) term of the tidypredict formula.
model <- xgboost::xgb.train(
  params = list(
    max_depth = 2,
    objective = "binary:logistic",
    base_score = 0.5
  ),
  data = xgb_bin_data,
  nrounds = 50
)
# Create the R formula
# Translate the booster into one R expression: a sum of case_when() terms
# wrapped in the logistic transform 1 - 1/(1 + exp(...)).
tidypredict::tidypredict_fit(model)
#> 1 - 1/(1 + exp(case_when((wt >= 3.19000006 | is.na(wt)) ~ -0.436363667,
#> qsec < 19.4400005 & wt < 3.19000006 ~ 0.428571463, (qsec >=
#> 19.4400005 | is.na(qsec)) & wt < 3.19000006 ~ 0) + case_when(wt <
#> 3.1500001 ~ 0.311573088, hp < 230 & (wt >= 3.1500001 | is.na(wt)) ~
#> -0.392053694, (hp >= 230 | is.na(hp)) & (wt >= 3.1500001 |
#> is.na(wt)) ~ -0.0240745768) + case_when(gear < 4 ~ -0.355945677,
#> wt < 3.1500001 & (gear >= 4 | is.na(gear)) ~ 0.325712085,
#> (wt >= 3.1500001 | is.na(wt)) & (gear >= 4 | is.na(gear)) ~
#> -0.0384863913) + case_when(gear < 4 ~ -0.309683114, wt <
#> 3.1500001 & (gear >= 4 | is.na(gear)) ~ 0.283893973, (wt >=
#> 3.1500001 | is.na(wt)) & (gear >= 4 | is.na(gear)) ~ -0.032039877) +
#> case_when(gear < 4 ~ -0.275577009, wt < 3.1500001 & (gear >=
#> 4 | is.na(gear)) ~ 0.252453178, (wt >= 3.1500001 | is.na(wt)) &
#> (gear >= 4 | is.na(gear)) ~ -0.0266750772) + case_when(gear <
#> 4 ~ -0.248323873, qsec < 17.0499992 & (gear >= 4 | is.na(gear)) ~
#> 0.261978835, (qsec >= 17.0499992 | is.na(qsec)) & (gear >=
#> 4 | is.na(gear)) ~ -0.00959526002) + case_when(gear < 4 ~
#> -0.225384533, wt < 3.1500001 & (gear >= 4 | is.na(gear)) ~
#> 0.218285918, (wt >= 3.1500001 | is.na(wt)) & (gear >= 4 |
#> is.na(gear)) ~ -0.0373593047) + case_when(gear < 4 ~ -0.205454513,
#> qsec < 18.8999996 & (gear >= 4 | is.na(gear)) ~ 0.196076646,
#> (qsec >= 18.8999996 | is.na(qsec)) & (gear >= 4 | is.na(gear)) ~
#> -0.0544253439) + case_when(wt < 3.1500001 ~ 0.149246693,
#> qsec < 17.4200001 & (wt >= 3.1500001 | is.na(wt)) ~ 0.0354709327,
#> (qsec >= 17.4200001 | is.na(qsec)) & (wt >= 3.1500001 | is.na(wt)) ~
#> -0.226075932) + case_when(gear < 4 ~ -0.184417158, wt <
#> 3.1500001 & (gear >= 4 | is.na(gear)) ~ 0.176768288, (wt >=
#> 3.1500001 | is.na(wt)) & (gear >= 4 | is.na(gear)) ~ -0.0237750355) +
#> case_when(gear < 4 ~ -0.168993726, qsec < 18.6100006 & (gear >=
#> 4 | is.na(gear)) ~ 0.155569643, (qsec >= 18.6100006 |
#> is.na(qsec)) & (gear >= 4 | is.na(gear)) ~ -0.0325752236) +
#> case_when(wt < 3.1500001 ~ 0.119126029, (wt >= 3.1500001 |
#> is.na(wt)) ~ -0.105012275) + case_when(qsec < 17.2999992 ~
#> 0.117254697, (qsec >= 17.2999992 | is.na(qsec)) ~ -0.0994235724) +
#> case_when(wt < 3.19000006 ~ 0.097100094, (wt >= 3.19000006 |
#> is.na(wt)) ~ -0.10567718) + case_when(wt < 3.19000006 ~
#> 0.0824323222, (wt >= 3.19000006 | is.na(wt)) ~ -0.091120176) +
#> case_when(qsec < 17.6000004 ~ 0.0854752287, (qsec >= 17.6000004 |
#> is.na(qsec)) ~ -0.0764453933) + case_when(wt < 3.19000006 ~
#> 0.0749477893, (wt >= 3.19000006 | is.na(wt)) ~ -0.0799863264) +
#> case_when(qsec < 17.8199997 ~ 0.0728750378, (qsec >= 17.8199997 |
#> is.na(qsec)) ~ -0.0646049976) + case_when(wt < 3.19000006 ~
#> 0.0682478622, (wt >= 3.19000006 | is.na(wt)) ~ -0.0711427554) +
#> case_when(wt < 3.19000006 ~ 0.0579533465, (wt >= 3.19000006 |
#> is.na(wt)) ~ -0.0613371208) + case_when(qsec < 18.2999992 ~
#> 0.0595484748, (qsec >= 18.2999992 | is.na(qsec)) ~ -0.0546668135) +
#> case_when(wt < 3.19000006 ~ 0.0535288528, (wt >= 3.19000006 |
#> is.na(wt)) ~ -0.0558333211) + case_when(wt < 3.19000006 ~
#> 0.0454574414, (wt >= 3.19000006 | is.na(wt)) ~ -0.048143398) +
#> case_when(qsec < 18.6000004 ~ 0.0422042683, (qsec >= 18.6000004 |
#> is.na(qsec)) ~ -0.0454404354) + case_when(wt < 3.19000006 ~
#> 0.0420555957, (wt >= 3.19000006 | is.na(wt)) ~ -0.0449385941) +
#> case_when(qsec < 18.6000004 ~ 0.0393446013, (qsec >= 18.6000004 |
#> is.na(qsec)) ~ -0.0425945036) + case_when(wt < 3.19000006 ~
#> 0.0391179025, (wt >= 3.19000006 | is.na(wt)) ~ -0.0420661867) +
#> case_when(qsec < 18.5200005 ~ 0.0304145869, (qsec >= 18.5200005 |
#> is.na(qsec)) ~ -0.031833414) + case_when(wt < 3.19000006 ~
#> 0.0362136625, (wt >= 3.19000006 | is.na(wt)) ~ -0.038949281) +
#> case_when(qsec < 18.5200005 ~ 0.0295153651, (qsec >= 18.5200005 |
#> is.na(qsec)) ~ -0.0307046026) + case_when(drat < 3.8499999 ~
#> -0.0306891855, (drat >= 3.8499999 | is.na(drat)) ~ 0.0288283136) +
#> case_when(qsec < 18.5200005 ~ 0.0271221269, (qsec >= 18.5200005 |
#> is.na(qsec)) ~ -0.0281750448) + case_when(qsec < 18.5200005 ~
#> 0.0228891298, (qsec >= 18.5200005 | is.na(qsec)) ~ -0.0238814205) +
#> case_when(drat < 3.8499999 ~ -0.0296511576, (drat >= 3.8499999 |
#> is.na(drat)) ~ 0.0280048084) + case_when(qsec < 18.5200005 ~
#> 0.0214707125, (qsec >= 18.5200005 | is.na(qsec)) ~ -0.0224219449) +
#> case_when(qsec < 18.5200005 ~ 0.0181306079, (qsec >= 18.5200005 |
#> is.na(qsec)) ~ -0.0190209728) + case_when(wt < 3.19000006 ~
#> 0.0379650332, (wt >= 3.19000006 | is.na(wt)) ~ -0.0395050682) +
#> case_when(qsec < 18.5200005 ~ 0.0194106717, (qsec >= 18.5200005 |
#> is.na(qsec)) ~ -0.0202215631) + case_when(qsec < 18.5200005 ~
#> 0.0164139606, (qsec >= 18.5200005 | is.na(qsec)) ~ -0.0171694476) +
#> case_when(qsec < 18.5200005 ~ 0.013879573, (qsec >= 18.5200005 |
#> is.na(qsec)) ~ -0.0145772658) + case_when(qsec < 18.5200005 ~
#> 0.0117362784, (qsec >= 18.5200005 | is.na(qsec)) ~ -0.0123759825) +
#> case_when(wt < 3.19000006 ~ 0.0388614088, (wt >= 3.19000006 |
#> is.na(wt)) ~ -0.0400568396) + log(0.5/(1 - 0.5))))
Add the prediction to the original table
library(dplyr)

# Add the model's prediction to mtcars as a new `fit` column, then preview
# the augmented table one column per line.
glimpse(tidypredict_to_column(mtcars, model))
#> Rows: 32
#> Columns: 12
#> $ mpg <dbl> 21.0, 21.0, 22.8, 21.4, 18.7, 18.1, 14.3, 24.4, 22.8, 19.2, 17.8,…
#> $ cyl <dbl> 6, 6, 4, 6, 8, 6, 8, 4, 4, 6, 6, 8, 8, 8, 8, 8, 8, 4, 4, 4, 4, 8,…
#> $ disp <dbl> 160.0, 160.0, 108.0, 258.0, 360.0, 225.0, 360.0, 146.7, 140.8, 16…
#> $ hp <dbl> 110, 110, 93, 110, 175, 105, 245, 62, 95, 123, 123, 180, 180, 180…
#> $ drat <dbl> 3.90, 3.90, 3.85, 3.08, 3.15, 2.76, 3.21, 3.69, 3.92, 3.92, 3.92,…
#> $ wt <dbl> 2.620, 2.875, 2.320, 3.215, 3.440, 3.460, 3.570, 3.190, 3.150, 3.…
#> $ qsec <dbl> 16.46, 17.02, 18.61, 19.44, 17.02, 20.22, 15.84, 20.00, 22.90, 18…
#> $ vs <dbl> 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0,…
#> $ am <dbl> 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,…
#> $ gear <dbl> 4, 4, 4, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 4, 4, 4, 3, 3,…
#> $ carb <dbl> 4, 4, 1, 1, 2, 1, 4, 2, 2, 4, 4, 3, 3, 3, 4, 4, 4, 1, 2, 1, 1, 2,…
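#> $ fit <dbl> 0.98576418, 0.98576418, 0.93905137, 0.01081509, 0.04639094, 0.010…
Confirm that tidypredict results match the model's predict() results
using tidypredict_test(). Its xg_df argument expects the xgb.DMatrix data
set (here, xgb_bin_data).
Please be aware that xgboost converts data into 32-bit floats internally, while the tidypredict formula compares the original 64-bit values against the split thresholds. A value that lies very close to a threshold (compare wt = 3.19 in the data with the 3.19000006 split above) can therefore fall on a different side of the split than it does inside xgboost. Always verify that the predictions match up with model predictions. See this issue for more information.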
parsnip fitted models are also supported by
tidypredict:
library(parsnip)

# Fit a comparable booster through parsnip. The mode is "regression", so the
# 0/1 `am` outcome is modelled as a numeric value here, unlike the
# "binary:logistic" model above. Fitting emits the `nthread` deprecation
# warning shown in the output.
p_model <- fit(
  set_engine(boost_tree(mode = "regression"), "xgboost"),
  am ~ .,
  data = mtcars
)
#> Warning in check.deprecation(deprecated_train_params, match.call(), ...):
#> Passed invalid function arguments: nthread. These should be passed as a list to
#> argument 'params'. Conversion from argument to 'params' entry will be done
#> automatically, but this behavior will become an error in a future version.
Here is an example of the model spec:
# Parse the booster into tidypredict's intermediate model spec: a list with
# `general` metadata and a `trees` list whose elements hold each tree's
# leaves (prediction plus the path of split conditions leading to it).
pm <- parse_model(model)
# Show only the top two levels of that nested list.
str(pm, max.level = 2)
#> List of 2
#> $ general:List of 7
#> ..$ model : chr "xgb.Booster"
#> ..$ type : chr "xgb"
#> ..$ params :List of 5
#> ..$ feature_names: chr [1:10] "mpg" "cyl" "disp" "hp" ...
#> ..$ niter : int 50
#> ..$ nfeatures : int 10
#> ..$ version : num 1
#> $ trees :List of 42
#> ..$ 0 :List of 3
#> ..$ 1 :List of 3
#> ..$ 2 :List of 3
#> ..$ 3 :List of 3
#> ..$ 4 :List of 3
#> ..$ 5 :List of 3
#> ..$ 6 :List of 3
#> ..$ 7 :List of 3
#> ..$ 8 :List of 3
#> ..$ 9 :List of 3
#> ..$ 10:List of 3
#> ..$ 11:List of 2
#> ..$ 12:List of 2
#> ..$ 13:List of 2
#> ..$ 14:List of 2
#> ..$ 15:List of 2
#> ..$ 16:List of 2
#> ..$ 17:List of 2
#> ..$ 18:List of 2
#> ..$ 19:List of 2
#> ..$ 20:List of 2
#> ..$ 21:List of 2
#> ..$ 22:List of 2
#> ..$ 23:List of 2
#> ..$ 24:List of 2
#> ..$ 25:List of 2
#> ..$ 26:List of 2
#> ..$ 27:List of 2
#> ..$ 28:List of 2
#> ..$ 29:List of 2
#> ..$ 30:List of 2
#> ..$ 31:List of 2
#> ..$ 32:List of 2
#> ..$ 33:List of 2
#> ..$ 34:List of 2
#> ..$ 35:List of 2
#> ..$ 36:List of 2
#> ..$ 37:List of 2
#> ..$ 38:List of 2
#> ..$ 39:List of 2
#> ..$ 40:List of 2
#> ..$ 41:List of 2
#> - attr(*, "class")= chr [1:3] "parsed_model" "pm_xgb" "list"
str(pm$trees[1])
#> List of 1
#> $ 0:List of 3
#> ..$ :List of 2
#> .. ..$ prediction: num -0.436
#> .. ..$ path :List of 1
#> .. .. ..$ :List of 5
#> .. .. .. ..$ type : chr "conditional"
#> .. .. .. ..$ col : chr "wt"
#> .. .. .. ..$ val : num 3.19
#> .. .. .. ..$ op : chr "less"
#> .. .. .. ..$ missing: logi TRUE
#> ..$ :List of 2
#> .. ..$ prediction: num 0.429
#> .. ..$ path :List of 2
#> .. .. ..$ :List of 5
#> .. .. .. ..$ type : chr "conditional"
#> .. .. .. ..$ col : chr "qsec"
#> .. .. .. ..$ val : num 19.4
#> .. .. .. ..$ op : chr "more-equal"
#> .. .. .. ..$ missing: logi FALSE
#> .. .. ..$ :List of 5
#> .. .. .. ..$ type : chr "conditional"
#> .. .. .. ..$ col : chr "wt"
#> .. .. .. ..$ val : num 3.19
#> .. .. .. ..$ op : chr "more-equal"
#> .. .. .. ..$ missing: logi FALSE
#> ..$ :List of 2
#> .. ..$ prediction: num 0
#> .. ..$ path :List of 2
#> .. .. ..$ :List of 5
#> .. .. .. ..$ type : chr "conditional"
#> .. .. .. ..$ col : chr "qsec"
#> .. .. .. ..$ val : num 19.4
#> .. .. .. ..$ op : chr "less"
#> .. .. .. ..$ missing: logi TRUE
#> .. .. ..$ :List of 5
#> .. .. .. ..$ type : chr "conditional"
#> .. .. .. ..$ col : chr "wt"
#> .. .. .. ..$ val : num 3.19
#> .. .. .. ..$ op : chr "more-equal"
#> .. .. .. ..$ missing: logi FALSE