How can I draw a ROC curve for a randomForest model with three classes in R?
I see two problems here:
1) ROC curves are defined for binary classifiers, so you should recast your performance evaluation as a series of binary problems. I show how to do this below.
2) When you predict your test set, you should get the probability that each observation belongs to each of your classes (rather than just the predicted class). This gives the ROC analysis many thresholds to work with, which is what lets you draw nice-looking ROC curves.
Here's the code:
#load libraries
library(randomForest)
library(pROC)
# generate some random data
set.seed(1111)
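# stringsAsFactors = TRUE: randomForest needs a factor response for classification (the default has been FALSE since R 4.0)
train <- data.frame(condition = sample(c("mock", "lethal", "resist"), replace = TRUE, size = 1000), stringsAsFactors = TRUE)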
train$feat01 <- sapply(train$condition, (function(i){ if (i == "mock") { rnorm(n = 1, mean = 0)} else if (i == "lethal") { rnorm(n = 1, mean = 1.5)} else { rnorm(n = 1, mean = -1.5)} }))
train$feat02 <- sapply(train$condition, (function(i){ if (i == "mock") { rnorm(n = 1, mean = 0)} else if (i == "lethal") { rnorm(n = 1, mean = 1.5)} else { rnorm(n = 1, mean = -1.5)} }))
train$feat03 <- sapply(train$condition, (function(i){ if (i == "mock") { rnorm(n = 1, mean = 0)} else if (i == "lethal") { rnorm(n = 1, mean = 1.5)} else { rnorm(n = 1, mean = -1.5)} }))
head(train)
test <- data.frame(condition = sample(c("mock", "lethal", "resist"), replace = TRUE, size = 1000), stringsAsFactors = TRUE)
test$feat01 <- sapply(test$condition, (function(i){ if (i == "mock") { rnorm(n = 1, mean = 0)} else if (i == "lethal") { rnorm(n = 1, mean = 1.5)} else { rnorm(n = 1, mean = -1.5)} }))
test$feat02 <- sapply(test$condition, (function(i){ if (i == "mock") { rnorm(n = 1, mean = 0)} else if (i == "lethal") { rnorm(n = 1, mean = 1.5)} else { rnorm(n = 1, mean = -1.5)} }))
test$feat03 <- sapply(test$condition, (function(i){ if (i == "mock") { rnorm(n = 1, mean = 0)} else if (i == "lethal") { rnorm(n = 1, mean = 1.5)} else { rnorm(n = 1, mean = -1.5)} }))
head(test)
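The repeated sapply() calls above can be condensed. The following is a minimal sketch of a vectorized alternative, assuming the same three class means (0, 1.5, -1.5); the helper name make_data and the lookup vector class_means are hypothetical and not part of the original answer.

```r
# Hypothetical helper: simulate n observations with three features whose
# mean depends on the class (same class means as in the code above)
make_data <- function(n) {
  class_means <- c(mock = 0, lethal = 1.5, resist = -1.5)
  condition <- sample(names(class_means), size = n, replace = TRUE)
  mu <- unname(class_means[condition])  # per-observation mean
  data.frame(
    condition = factor(condition),      # factor response for classification
    feat01 = rnorm(n, mean = mu),
    feat02 = rnorm(n, mean = mu),
    feat03 = rnorm(n, mean = mu)
  )
}

set.seed(1111)
train <- make_data(1000)
test <- make_data(1000)
```

Because the random draws happen in a different order, the simulated values will not match those from the code above, even with the same seed.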
Now that we have some data, let's train a Random Forest model as you did.
# model
model <- randomForest(formula = condition ~ ., data = train, ntree = 10, maxnodes = 100, norm.votes = FALSE)
Next, the model is used to predict the test data. However, you should ask for type = "prob" here.
# predict test set, get probs instead of response
predictions <- as.data.frame(predict(model, test, type = "prob"))
Since you have probabilities, use them to get the most-likely class.
# predict class and then attach test class
predictions$predict <- names(predictions)[1:3][apply(predictions[,1:3], 1, which.max)]
predictions$observed <- test$condition
head(predictions)
  lethal mock resist predict observed
1    0.0  0.0    1.0  resist   resist
2    0.0  0.6    0.4    mock     mock
3    1.0  0.0    0.0  lethal     mock
4    0.0  0.0    1.0  resist   resist
5    0.0  1.0    0.0    mock     mock
6    0.7  0.3    0.0  lethal     mock
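Before moving on to the ROC curves, it can help to tabulate predicted against observed classes. The sketch below is only an illustration: it re-enters the six rows printed above by hand, picks the most likely class with base R's max.col() (an alternative to the apply/which.max idiom), and builds a confusion table. The names classes, conf_mat and accuracy are made up for this example.

```r
# The six rows shown in the head(predictions) output above, typed in by hand
predictions <- data.frame(
  lethal   = c(0.0, 0.0, 1.0, 0.0, 0.0, 0.7),
  mock     = c(0.0, 0.6, 0.0, 0.0, 1.0, 0.3),
  resist   = c(1.0, 0.4, 0.0, 1.0, 0.0, 0.0),
  observed = c("resist", "mock", "mock", "resist", "mock", "mock")
)
classes <- c("lethal", "mock", "resist")

# max.col() returns, for each row, the index of the column holding the maximum
best <- max.col(as.matrix(predictions[, classes]), ties.method = "first")
predictions$predict <- classes[best]

# Confusion table: rows are observed classes, columns are predicted classes
conf_mat <- table(
  observed  = factor(predictions$observed, levels = classes),
  predicted = factor(predictions$predict, levels = classes)
)
accuracy <- sum(diag(conf_mat)) / sum(conf_mat)
print(conf_mat)
```

On these six rows, four predictions match the observed class; the two misses are observed "mock" rows predicted as "lethal".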
Now, let's see how to plot the ROC curves. For each class, convert the multi-class problem into a binary problem (that class versus the rest). Then call the roc() function with two arguments: i) the observed classes and ii) the probability of that class (instead of the predicted class).
# 1 ROC curve, mock vs non mock
roc.mock <- roc(ifelse(predictions$observed=="mock", "mock", "non-mock"), as.numeric(predictions$mock))
plot(roc.mock, col = "gray60")
# others
# each curve must be ranked by the probability of its own class
roc.lethal <- roc(ifelse(predictions$observed=="lethal", "lethal", "non-lethal"), as.numeric(predictions$lethal))
roc.resist <- roc(ifelse(predictions$observed=="resist", "resist", "non-resist"), as.numeric(predictions$resist))
lines(roc.lethal, col = "blue")
lines(roc.resist, col = "red")
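Done. This is the result. Of course, the more observations in your test set, the smoother your curves will be. The number of trees matters too: with ntree = 10, each class probability can only take 11 distinct values (0, 0.1, ..., 1), so each curve has at most that many thresholds.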
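The three one-vs-rest calls can also be written as a loop, which keeps each class paired with its own probability column. This is a sketch under assumptions rather than the original answer's code: the toy scores in probs are synthetic stand-ins for the output of predict(..., type = "prob"), and the names classes, probs, rocs, cols and aucs are illustrative. It sets levels and direction explicitly so that each class is treated as the "case" and higher probabilities count in its favour.

```r
library(pROC)

set.seed(42)
classes <- c("lethal", "mock", "resist")
observed <- sample(classes, size = 300, replace = TRUE)

# Synthetic class scores (assumption): higher for the true class,
# then normalised so that each row sums to 1
scores <- sapply(classes, function(cl) runif(300) + (observed == cl))
probs <- as.data.frame(scores / rowSums(scores))

# One binary ROC analysis per class: that class vs. the rest,
# ranked by the probability column of the same class
rocs <- lapply(classes, function(cl) {
  roc(response = ifelse(observed == cl, cl, "rest"),
      predictor = probs[[cl]],
      levels = c("rest", cl),   # controls first, cases second
      direction = "<",          # cases are expected to score higher
      quiet = TRUE)
})
names(rocs) <- classes

# Area under each curve as a one-number summary
aucs <- sapply(rocs, function(r) as.numeric(auc(r)))

# Overlay the three curves and label them
cols <- c(lethal = "blue", mock = "gray60", resist = "red")
plot(rocs[[1]], col = cols[[1]])
for (cl in classes[-1]) lines(rocs[[cl]], col = cols[[cl]])
legend("bottomright",
       legend = sprintf("%s vs rest (AUC = %.2f)", classes, aucs),
       col = unname(cols[classes]), lwd = 2)
```

To use it on real output, replace observed and probs with the observed test classes and the probability columns from the prediction step.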
Adam Price
Updated on September 13, 2020
Comments
-
Adam Price over 3 years
I'm using the R package, randomForest, to create a model that classifies into three groups.
model = randomForest(formula = condition ~ ., data = train, ntree = 2000, mtry = bestm, importance = TRUE, proximity = TRUE)

               Type of random forest: classification
                     Number of trees: 2000
No. of variables tried at each split: 3

        OOB estimate of  error rate: 5.71%
Confusion matrix:
          lethal mock resistant class.error
lethal        20    1         0  0.04761905
mock           1   37         0  0.02631579
resistant      2    0         9  0.18181818
I've tried with several libraries. For instance, with ROCR, you can't do three classifications, only two. Behold:
pred=prediction(predictions,train$condition)
Error in prediction(predictions, train$condition) :
  Number of classes is not equal to 2.
ROCR currently supports only evaluation of binary classification tasks.
The data from model$votes looks like this:
        lethal        mock   resistant
3   0.04514364 0.952120383 0.002735978
89  0.32394366 0.147887324 0.528169014
16  0.02564103 0.973009447 0.001349528
110 0.55614973 0.433155080 0.010695187
59  0.06685633 0.903271693 0.029871977
43  0.13424658 0.865753425 0.000000000
41  0.82987552 0.033195021 0.136929461
86  0.32705249 0.468371467 0.204576043
87  0.37704918 0.341530055 0.281420765
........
I can get some pretty ugly ROC plots this way using the pROC package:
predictions <- as.numeric(predict(model, test, type = 'response'))
roc.multi <- multiclass.roc(test$condition, predictions, percent=TRUE)
rs <- roc.multi[['rocs']]
plot.roc(rs[[2]])
sapply(2:length(rs),function(i) lines.roc(rs[[i]],col=i))
There's no way to smooth those lines though, because they aren't so much of a curve as they are 4 or so points each.
I need a way to plot a nice smooth ROC curve for this model, but I can't seem to find one. Does anyone know of a good approach? Thanks very much in advance!
-
lebelinoz over 6 years
Could you include the bare Sensitivity + Specificity data? Lots of folks on this site can help you create beautiful charts without understanding the nuances of random forests and ROC.
-
Damiano Fantini over 6 years
I think you can use either ROCR or pROC, but the best way to show this is to superimpose three lines corresponding to the three possible contrasts you have: mock vs. non-mock, lethal vs. non-lethal, and resist vs. non-resist. So, you should convert your problem into 3 binary problems and plot the corresponding ROC curves...
-
user20650 over 6 years
My initial thought on this was also that ROC required binomial outcomes, but after searching the web for multiclass.roc, many links came up; the top search result was stats.stackexchange.com/questions/2151/… . Although maybe this is what you are doing???
-
Damiano Fantini over 6 years
@user20650 Good observation. To my understanding, multiclass.roc() only accepts one probability vector (the predictor argument) to rank the data. Random Forests assume no linearity in the response, and return n probability vectors (where n is the number of classes). Here, I am showing a way to deal with the problem by superimposing three standard (binary) ROC analyses.
-
Adam Price over 6 years
This is very helpful, I learned a lot from your post. Thank you!