Skip to content

Commit

Permalink
fix another bug in roc/prc vis (fixes #79)
Browse files Browse the repository at this point in the history
  • Loading branch information
mllg committed Aug 11, 2021
1 parent 0dc09f4 commit 8200e6f
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 14 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3viz
Title: Visualizations for 'mlr3'
Version: 0.5.4
Version: 0.5.5
Authors@R:
c(person(given = "Michel",
family = "Lang",
Expand Down
7 changes: 5 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# mlr3viz 0.5.3.9000
# mlr3viz 0.5.5

- Fixed a bug for ROC- and Precision-recall-curves (#72, #75).
- Fixed another bug for ROC- and Precision-recall-curves (#79).

# mlr3viz 0.5.4

- Fixed a bug for ROC- and Precision-recall-curves (#72, #75).

# mlr3viz 0.5.3

Expand Down
16 changes: 9 additions & 7 deletions R/as_precrec.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ roc_data = function(prediction) {
}

data.table(
scores = prediction$prob[, 1L],
scores = prediction$prob[, 1L, drop = TRUE],
labels = prediction$truth
)
}
Expand All @@ -42,7 +42,7 @@ as_precrec.PredictionClassif = function(object) { # nolint
scores = data$scores,
labels = data$labels,
dsids = 1L,
posclass = levels(data$labels)[1L]
posclass = levels(object$truth)[1L]
)
}

Expand All @@ -53,10 +53,12 @@ as_precrec.ResampleResult = function(object) { # nolint
require_namespaces("precrec")
predictions = object$predictions()
data = transpose_list(map(predictions, roc_data))

precrec::mmdata(
scores = data$scores, labels = data$labels,
scores = data$scores,
labels = data$labels,
dsids = seq_along(predictions),
posclass = levels(data$labels)[1L]
posclass = object$task$positive
)
}

Expand All @@ -82,10 +84,10 @@ as_precrec.BenchmarkResult = function(object) { # nolint
lrns = unique(scores$learner_id)
iters = unique(scores$iteration)
precrec::mmdata(
data$scores,
data$labels,
scores = data$scores,
labels = data$labels,
dsids = iters,
modnames = lrns,
posclass = levels(data$labels)[1L]
posclass = object$tasks$task[[1L]]$positive
)
}
4 changes: 4 additions & 0 deletions tests/testthat/test_BenchmarkResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,8 @@ test_that("holdout roc plot (#54)", {
bmr = benchmark(design)
p = autoplot(bmr, type = "roc")
expect_true(is.ggplot(p))

# roc is not inverted?
tab = as.data.table(precrec::auc(precrec::evalmod(as_precrec(bmr))))
expect_number(tab[modnames == "classif.rpart" & curvetypes == "ROC", aucs], lower = 0.5)
})
14 changes: 10 additions & 4 deletions tests/testthat/test_PredictionClassif.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
test_that("autoplot.PredictionClassif", {
task = mlr3::tsk("sonar")
learner = mlr3::lrn("classif.rpart", predict_type = "prob")$train(task)
prediction = learner$predict(task)
task = mlr3::tsk("sonar")
learner = mlr3::lrn("classif.rpart", predict_type = "prob")$train(task)
prediction = learner$predict(task)

test_that("autoplot.PredictionClassif", {
p = autoplot(prediction, type = "stacked")
expect_true(is.ggplot(p))

Expand All @@ -15,3 +15,9 @@ test_that("autoplot.PredictionClassif", {
p = autoplot(prediction, type = "threshold")
expect_true(is.ggplot(p))
})

test_that("roc is not inverted", {
skip_if_not_installed("precrec")
tab = as.data.table(precrec::auc(precrec::evalmod(as_precrec(prediction))))
expect_numeric(tab[curvetypes == "ROC", aucs], len = 1L, lower = 0.5)
})
8 changes: 8 additions & 0 deletions tests/testthat/test_ResampleResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,11 @@ test_that("autoplot ResampleResult type=prediction", {
regexp = "Plot learner prediction only works with one or two features for
regression!")
})


test_that("roc is not inverted", {
autoplot(rr, type = "roc")
skip_if_not_installed("precrec")
tab = as.data.table(precrec::auc(precrec::evalmod(as_precrec(rr))))
expect_number(mean(tab[curvetypes == "ROC", aucs]), lower = 0.5)
})

0 comments on commit 8200e6f

Please sign in to comment.