Skip to content

Commit de7dbe3

Browse files
authored
Merge pull request #102 from edgararuiz/fix-test
Fixes earth GLM models and prediction test routine
2 parents 6a0a4e7 + eeeddd6 commit de7dbe3

File tree

3 files changed

+16
-10
lines changed

3 files changed

+16
-10
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# tidypredict (development version)
22

3+
- Fixes issue handling GLM Binomial earth models (#97)
4+
35
- Adds capability to handle single simple Cubist models (#57)
46

57
- Fixed parenthesis issue in the creation of the interval formula (#76)

R/model-earth.R

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@ parse_model.earth <- function(model) {
1414
}
1515

1616
is_glm <- !is.null(model$glm.list)
17-
18-
all_coefs <- model$coefficients
19-
if (is_glm) all_coefs <- model$glm.coefficients
20-
17+
2118
pm <- list()
2219
pm$general$model <- "earth"
2320
pm$general$type <- "tree"
@@ -29,12 +26,12 @@ parse_model.earth <- function(model) {
2926
pm$general$family <- fam$family
3027
pm$general$link <- fam$link
3128
}
32-
pm$terms <- mars_terms(model)
29+
pm$terms <- mars_terms(model, is_glm)
3330
as_parsed_model(pm)
3431
}
3532

3633

37-
mars_terms <- function(mod) {
34+
mars_terms <- function(mod, is_glm) {
3835
feature_types <-
3936
tibble::as_tibble(mod$dirs, rownames = "feature") %>%
4037
dplyr::mutate(feature_num = dplyr::row_number()) %>%
@@ -48,9 +45,16 @@ mars_terms <- function(mod) {
4845
tidyr::pivot_longer(cols = c(-feature, -feature_num),
4946
values_to = "value",
5047
names_to = "term")
48+
49+
if (is_glm) {
50+
all_coefs <- mod$glm.coefficients
51+
} else {
52+
all_coefs <- mod$coefficients
53+
}
54+
5155
feature_coefs <-
5256
# Note coef(mod) formats data differently for logistic regression
53-
tibble::as_tibble(mod$coefficients, rownames = "feature") %>%
57+
tibble::as_tibble(all_coefs, rownames = "feature") %>%
5458
setNames(c("feature", "coefficient"))
5559

5660
term_to_column <-

R/test-predictions.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,12 @@ tidypredict_test_default <- function(model, df = model$model, threshold = 0.0000
6868

6969
if (is.numeric(max_rows)) df <- head(df, max_rows)
7070

71-
base <- predict(model, df, interval = interval, type = "response")
71+
preds <- predict(model, df, interval = interval, type = "response")
7272

7373
if (!include_intervals) {
74-
base <- data.frame(fit = base, row.names = NULL)
74+
base <- data.frame(fit = as.vector(preds), row.names = NULL)
7575
} else {
76-
base <- as.data.frame(base)
76+
base <- as.data.frame(preds)
7777
}
7878

7979
te <- tidypredict_to_column(

0 commit comments

Comments
 (0)