Skip to content

Commit

Permalink
add precision test
Browse files Browse the repository at this point in the history
  • Loading branch information
rrrrn committed Nov 7, 2023
1 parent cbbbe24 commit 08b714f
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 48 deletions.
42 changes: 29 additions & 13 deletions R/accuracy.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ binary_acc <- function(preds, target, threshold=0.5, multidim_average = "global"
#' @param average Defines the reduction that is applied over labels.
#' Micro-sum over all class labels.
#' Macro-calculate class label-wise statistics and then take the average.
#' The parameter average makes an effect in calculation only when the accuracy
#' is required on a global level
#' @param multidim_average Average model: global-average across all accuracies,
#' samplewise-average across the all but the first dimensions (calculated
#' independently for each sample)
Expand Down Expand Up @@ -83,20 +81,38 @@ multiclass_acc <- function(preds, target, multidim_average = "global",
stopifnot(dim(preds)==dim(target))
stopifnot(num_class>0)

if(multidim_average=="global"&average=="micro"){
return(mean(preds==target))
}
else if(multidim_average=="global"&average=="macro"){
label_acc = numeric(num_class)
# label-wise accuracy calculation
for(i in 1:num_class){
targetnew <- target[target==ele_all[i]]
predsnew <- preds[target==ele_all[i]]
label_acc[i] <- ifelse(length(targetnew==predsnew)>0, mean(targetnew==predsnew),0)
# generalized steps for computing scores
comp_assist = function(datamtx, average){
if(length(dim(datamtx))==1|is.null(dim(datamtx))){
n = length(datamtx)/2
preds = datamtx[1:n]
target = datamtx[(n+1):(2*n)]
}
else{
n = ncol(datamtx)/2
preds = datamtx[,1:n]
target = datamtx[,(n+1):(2*n)]
}

if(average=="micro"){
return(mean(preds==target))
}
else if(average=="macro"){
label_acc = numeric(num_class)
# label-wise accuracy calculation
for(i in 1:num_class){
targetnew <- target[target==ele_all[i]]
predsnew <- preds[target==ele_all[i]]
label_acc[i] <- ifelse(length(targetnew==predsnew)>0, mean(targetnew==predsnew),0)
}
}
return(mean(label_acc))
}

if(multidim_average=="global"){
return(comp_assist(cbind(preds,target), average))
}
else if(multidim_average=="samplewise"){
return(apply(target==preds, 1, mean))
return(apply(cbind(preds,target), 1, comp_assist, average = average))
}
}
25 changes: 13 additions & 12 deletions R/confusion_scores.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' confusion_scores
#'
#' @description Calculate confusion matrix for a given predicted set of
#' values and corresponding targets, mainly suitable for binary classification task.
#' values and corresponding targets, mainly suitable for binary classification task.
#'
#' @param preds Predicted label, same shape as target label
#' @param target Target label
Expand Down Expand Up @@ -64,7 +64,7 @@ confusion_scores <- function(preds, target, multidim_average="global"){
#' @param multidim_average Average model: global-average across all accuracies,
#' samplewise-average across the all but the first dimensions (calculated
#' independently for each sample)
#' @param class If multidim_average is set to "samplewise", this param specifies
#' @param classtype If multidim_average is set to "samplewise", this param specifies
#' particular class of interest to compute confusion matrix
#'
#' @return If under "global" average mode, the cross-sample multiclass confusion
Expand All @@ -76,8 +76,8 @@ confusion_scores <- function(preds, target, multidim_average="global"){
#'
#' y_pred = c("A","B","C","A","B")
#' y_target = rep("A", 5)
#' multiclass_confusion_scores(y_pred, y_target, class="A")
multiclass_confusion_scores <- function(preds, target, class=NULL,
#' multiclass_confusion_scores(y_pred, y_target, classtype="A")
multiclass_confusion_scores <- function(preds, target, classtype=NULL,
multidim_average = "global"){

ele_all <- factor(unique(c(target, preds))) # element in the union of two vec
Expand All @@ -88,22 +88,23 @@ multiclass_confusion_scores <- function(preds, target, class=NULL,
preds <- factor(preds, levels = ele_all)
target <- factor(target, levels = ele_all)
cfs_mtx <- table(preds, target)
if(length(class)==0){
if(length(classtype)==0|length(ele_all)==1){
return(cfs_mtx)
}
else if(length(class)==1){
tp = cfs_mtx[class, class]
fp = sum(cfs_mtx[class, ])-tp
fn = sum(cfs_mtx[, class])-tp
else if(length(classtype)==1){
classtype=as.character(classtype)
tp = cfs_mtx[classtype, classtype]
fp = sum(cfs_mtx[classtype, ])-tp
fn = sum(cfs_mtx[, classtype])-tp
tn = sum(cfs_mtx)-tp-fp-fn
cfsmtx = matrix(c(tp, fn, fp, tn), 2, 2)
return(list(matrix = cfsmtx, tp=tp, fn=fn, fp=fp, tn=tn))
}
}
else if(multidim_average=="samplewise"&length(class)==1){
else if(multidim_average=="samplewise"&length(classtype)==1){
dimpred = dim(preds)
prednew <- matrix(as.numeric(preds==class), dimpred[1], dimpred[2])
targetnew <- matrix(as.numeric(target==class), dimpred[1], dimpred[2])
prednew <- matrix(as.numeric(preds==classtype), dimpred[1], dimpred[2])
targetnew <- matrix(as.numeric(target==classtype), dimpred[1], dimpred[2])
return(confusion_scores(prednew, targetnew, multidim_average = multidim_average))
}
else{
Expand Down
118 changes: 96 additions & 22 deletions R/precision.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,28 @@
binary_precision <- function(preds, target, threshold=0.5,multidim_average = "global"){
#' binary_precision
#'
#' @description Calculate the binary classification precision for a given predicted set of
#' values and corresponding targets. In other words, this function estimates the
#' proportion of positive predictions made by the model that are actually correct.
#'
#' @param preds Predicted labels or predicted probability between 0 and 1,
#' same shape as target label
#' @param target Target label
#' @param threshold The numerical cut-off between 0 and 1 to transform
#' predicted probability into binary predicted labels
#' @param multidim_average Average mode: global-average across all precision scores,
#' samplewise-average across all but the first dimension (calculated
#' independently for each sample)
#'
#' @return Binary precision value for preds and target, with format dictated by
#' multidim_average command.
#'
#' @export
#'
#' @examples
#' binary_precision(c(0.8, 0.2), c(1,1), 0.3)
#' binary_precision(c(1,1), c(0,1))
binary_precision <- function(preds, target, threshold=0.5, multidim_average = "global"){

stopifnot(dim(preds)==dim(target))

# transform probability into labels when necessary
Expand All @@ -7,45 +31,95 @@ binary_precision <- function(preds, target, threshold=0.5,multidim_average = "gl
}

cfs_mtx <- confusion_scores(preds, target, multidim_average)
if(any((cfs_mtx$tp+cfs_mtx$fp==0))){
warning("NaN generated due to lack of positively predicted labels")
}
return((cfs_mtx$tp)/(cfs_mtx$tp+cfs_mtx$fp))
}

#' multiclass_precision
#'
#' @description Calculate the multiclass precision value for a given predicted set of
#' values and corresponding targets
#'
#' @param preds Predicted label with the same shape as target label, or
#' predicted probability between 0 and 1 for each class that has one
#' additional dimension compared with target label
#' @param target Target label that has been transformed into distinct integers
#' to refer to each class
#' @param average Defines the reduction that is applied over labels.
#' Micro-sum over all class labels, that is all true positives for each class divided
#' by all positive predicted values for each class.
#' Macro-calculate class label-wise precision scores and then take the average.
#' @param multidim_average Average mode: global-average across all precision scores,
#' samplewise-average across all but the first dimension (calculated
#' independently for each sample)
#'
#' @return Multiclass precision for preds and target, with format dictated by
#' multidim_average argument and average methods choice.
#'
#' @export
#'
#' @examples
#' y_pred = matrix(c(0.1, 0.5, 0.4, 0.9, 0.2, 0.8), 2,3)
#' y_target = c(2,1)
#' multiclass_precision(y_pred, y_target)
multiclass_precision <-function(preds, target, multidim_average = "global",
average = "micro"){
# transform probability into labels when necessary
if((length(dim(preds))==length(dim(target))+1)){
if((length(dim(preds))==length(dim(target))+1)|(length(dim(preds))>=2&is.null(dim(target)))){
# the last dimension always be the probabilities for each class
preds = apply(preds, 1:(length(dim(preds))-1), which.max)
}

# validate the multiclass assumption
ele_all <- unique(c(target, preds)) # element in the union of two vec
stopifnot(length(ele_all)>=0)
stopifnot(length(target)==length(preds))
num_class = length(ele_all)
ele_all <- unique(c(preds, target))
num_class <- length(ele_all)

stopifnot(dim(preds)[1]==dim(target)[1])
stopifnot(dim(preds)==dim(target))
stopifnot(num_class>0)

if(length(ele_all)==1){
tp <- length(preds)
tn <- fp <- fn <-0
}

stopifnot(dim(preds)==dim(target))
stopifnot(num_class>=length(unique(c(target))))
# generalized steps for computing scores
comp_assist = function(datamtx, average){
if(length(dim(datamtx))==1|is.null(dim(datamtx))){
n = length(datamtx)/2
preds = datamtx[1:n]
target = datamtx[(n+1):(2*n)]
}
else{
n = ncol(datamtx)/2
preds = datamtx[,1:n]
target = datamtx[,(n+1):(2*n)]
}

if(multidim_average=="samplewise"|average=="micro"){
precision_0 = (binary_precision(preds, target, multidim_average = multidim_average))
return(precision_0)
}
else if(average=="macro"){
label_precision = numeric(num_class)
# label-wise accuracy calculation
i = 1
for(ele in ele_all){
targetnew <- target[target==ele]
predsnew <- preds[target==ele]
label_precision[i] <- multiclass_confusion_scores(predsnew, targetnew, class=ele)
i = i+1
if(average=="micro"){
cfsmtx <- multiclass_confusion_scores(preds, target)
tp <- sum(diag(cfsmtx))
return((tp/sum(cfsmtx)))
}
else if(average=="macro"){
label_prec = numeric(num_class)
# label-wise accuracy calculation
for(i in 1:num_class){
cfsmtx <- multiclass_confusion_scores(preds, target, classtype=ele_all[i])
label_prec[i] <- cfsmtx$tp/(cfsmtx$tp+cfsmtx$fp)
}
return(mean(label_prec))
}
return(mean(label_acc))
}


if(multidim_average=="global"){
return(comp_assist(cbind(preds,target), average))
}
else if(multidim_average=="samplewise"){
return(apply(cbind(preds,target), 1, comp_assist, average = average))
}

}
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@

<!-- badges: start -->
[![R-CMD-check](https://github.com/rrrrn/mmetrics/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/rrrrn/mmetrics/actions/workflows/R-CMD-check.yaml)
[![Codecov test coverage](https://codecov.io/gh/rrrrn/mmetrics/branch/main/graph/badge.svg)](https://app.codecov.io/gh/rrrrn/mmetrics?branch=main)
<!-- badges: end -->
14 changes: 14 additions & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Codecov configuration for the package's test-coverage reporting.
# Turn off Codecov's pull-request comment.
comment: false

coverage:
  status:
    # Status check for coverage of the whole project.
    project:
      default:
        target: auto
        threshold: 1%
        informational: true
    # Status check for coverage of the lines touched by a change.
    patch:
      default:
        target: auto
        threshold: 1%
        informational: true
5 changes: 5 additions & 0 deletions tests/testthat/test-confusion_scores.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,9 @@ test_that("multiclass confusion_scores", {
target = t(preds)
result = multiclass_confusion_scores(preds, target, multidim_average = "samplewise", class=1)
expect_equal(result$tp, c(2, 0, 0, 2))

preds = rep(1,4)
target = preds
result = multiclass_confusion_scores(preds, target)
expect_equal(result[[1]], 4)
})
2 changes: 1 addition & 1 deletion tests/testthat/test-multiclass_acc.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ test_that("multiclass_acc function", {

preds = matrix(c(1,2,3,1,1,3),2,3)
target = matrix(c(1,1,3,3,1,1),2,3)
result = multiclass_acc(preds, target, multidim_average = "samplewise")
result = multiclass_acc(preds, target, multidim_average = "samplewise", average="micro")
expect_equal(result, c(1, 0))
})
47 changes: 47 additions & 0 deletions tests/testthat/test-precision.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
test_that("binary precision", {
  # Label input: every prediction is positive and three of the four targets
  # are positive, so the expected precision is 3/4.
  predicted <- c(1, 1, 1, 1)
  actual <- c(1, 0, 1, 1)
  expect_equal(binary_precision(predicted, actual), 3 / 4)

  # Score input with a custom threshold of 0.7: only the last score (0.8)
  # lies above it, and its target is positive.
  predicted <- c(0, 0, 0, 0.8)
  actual <- c(1, 0, 1, 1)
  expect_equal(binary_precision(predicted, actual, threshold = 0.7), 1)

  # Matrix input in samplewise mode: one expected value per row.
  predicted <- matrix(c(1, 1, 0, 1, 0, 1), 2, 3)
  actual <- matrix(c(1, 1, 0, 0, 0, 1), 2, 3)
  per_sample <- binary_precision(
    predicted,
    actual,
    multidim_average = "samplewise"
  )
  expect_equal(per_sample, c(1, 2 / 3))
})

test_that("multiclass_precision precision", {
  # Five distinct predicted labels against a constant target: exactly one of
  # the five predictions matches.
  predicted <- seq(1, 5)
  actual <- rep(2, 5)
  expect_equal(multiclass_precision(predicted, actual), .2)

  # Same label vectors scored with the default averaging and with "macro".
  predicted <- c(1, 1, 2, 1, 5)
  actual <- c(1, 1, 2, 2, 5)
  default_score <- multiclass_precision(predicted, actual)
  macro_score <- multiclass_precision(predicted, actual, average = "macro")
  expect_equal(default_score, .8)
  expect_equal(macro_score, 8 / 9)

  # Score-matrix input (one row per sample, one column per class) paired with
  # a plain target vector.
  predicted <- matrix(c(.99, .1, .23, .5), 2, 2)
  actual <- c(1, 1)
  expect_equal(multiclass_precision(predicted, actual), .5)

  # Label matrices in samplewise mode: one expected value per row, for both
  # averaging methods.
  predicted <- matrix(c(1, 3, 3, 1, 1, 3), 2, 3)
  actual <- matrix(c(1, 1, 3, 3, 1, 3), 2, 3)
  micro_rows <- multiclass_precision(
    predicted,
    actual,
    multidim_average = "samplewise",
    average = "micro"
  )
  macro_rows <- multiclass_precision(
    predicted,
    actual,
    multidim_average = "samplewise",
    average = "macro"
  )
  expect_equal(micro_rows, c(1, 1 / 3))
  expect_equal(macro_rows, c(1, 1 / 4))

  # Degenerate case: a single class present, with predictions identical to
  # the target.
  predicted <- rep(1, 4)
  actual <- predicted
  expect_equal(multiclass_precision(predicted, actual), 1)
})

0 comments on commit 08b714f

Please sign in to comment.