How can I draw a ROC curve for a randomForest model with three classes in R?


I see two problems here. 1) ROC curves work for binary classifiers, so you should convert your performance evaluation into a series of binary problems; I show below how to do this. 2) When you predict your test set, you should get the probability that each observation belongs to each of your classes (rather than just the predicted class). This will allow you to draw nice-looking ROC curves. Here's the code:

#load libraries
library(randomForest)
library(pROC)

# generate some random data
set.seed(1111)
train <- data.frame(condition = sample(c("mock", "lethal", "resist"), replace = TRUE, size = 1000))
train$feat01 <- sapply(train$condition, (function(i){ if (i == "mock") { rnorm(n = 1, mean = 0)} else if (i == "lethal") { rnorm(n = 1, mean = 1.5)} else { rnorm(n = 1, mean = -1.5)} }))
train$feat02 <- sapply(train$condition, (function(i){ if (i == "mock") { rnorm(n = 1, mean = 0)} else if (i == "lethal") { rnorm(n = 1, mean = 1.5)} else { rnorm(n = 1, mean = -1.5)} }))
train$feat03 <- sapply(train$condition, (function(i){ if (i == "mock") { rnorm(n = 1, mean = 0)} else if (i == "lethal") { rnorm(n = 1, mean = 1.5)} else { rnorm(n = 1, mean = -1.5)} }))
head(train)

test <- data.frame(condition = sample(c("mock", "lethal", "resist"), replace = TRUE, size = 1000))
test$feat01 <- sapply(test$condition, (function(i){ if (i == "mock") { rnorm(n = 1, mean = 0)} else if (i == "lethal") { rnorm(n = 1, mean = 1.5)} else { rnorm(n = 1, mean = -1.5)} }))
test$feat02 <- sapply(test$condition, (function(i){ if (i == "mock") { rnorm(n = 1, mean = 0)} else if (i == "lethal") { rnorm(n = 1, mean = 1.5)} else { rnorm(n = 1, mean = -1.5)} }))
test$feat03 <- sapply(test$condition, (function(i){ if (i == "mock") { rnorm(n = 1, mean = 0)} else if (i == "lethal") { rnorm(n = 1, mean = 1.5)} else { rnorm(n = 1, mean = -1.5)} }))
head(test)

Now that we have some data, let's train a Random Forest model as you did.

# model
model <- randomForest(formula = condition ~ ., data = train, ntree = 10, maxnodes = 100, norm.votes = FALSE)

Next, the model is used to predict the test data. However, you should ask for type="prob" here.

# predict test set, get probs instead of response
predictions <- as.data.frame(predict(model, test, type = "prob"))

Since you have probabilities, use them to get the most-likely class.

# predict class and then attach test class
predictions$predict <- names(predictions)[1:3][apply(predictions[,1:3], 1, which.max)]
predictions$observed <- test$condition
head(predictions)
  lethal mock resist predict observed
1    0.0  0.0    1.0  resist   resist
2    0.0  0.6    0.4    mock     mock
3    1.0  0.0    0.0  lethal     mock
4    0.0  0.0    1.0  resist   resist
5    0.0  1.0    0.0    mock     mock
6    0.7  0.3    0.0  lethal     mock
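
As a quick sanity check before drawing any curves, the predicted and observed columns can be cross-tabulated. Below is a minimal sketch in base R that uses only the six rows printed above, re-entered by hand; the names `lv`, `cm`, and `accuracy` are just for illustration. On the real data you would pass `predictions$predict` and `predictions$observed` instead.

```r
# The six rows printed above, re-entered by hand as a toy example
lv <- c("lethal", "mock", "resist")
predicted <- factor(c("resist", "mock", "lethal", "resist", "mock", "lethal"), levels = lv)
observed  <- factor(c("resist", "mock", "mock", "resist", "mock", "mock"), levels = lv)

# Rows = predicted class, columns = observed class
cm <- table(predicted = predicted, observed = observed)
print(cm)

# Overall accuracy: share of observations on the diagonal
accuracy <- sum(diag(cm)) / sum(cm)
print(accuracy)
```

Six rows are far too few to judge the model; the point is only the shape of the check.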

Now, let's see how to plot the ROC curves. For each class, convert the multi-class problem into a binary problem. Also, call the roc() function specifying 2 arguments: i) observed classes and ii) class probability (instead of predicted class).
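
To make the one-vs-rest idea concrete, here is a rough base-R sketch of what a binary ROC analysis computes from a class probability and a yes/no label. The toy vectors `is_mock` and `p_mock` are made up for illustration, and this is not how pROC is implemented internally; it only shows the idea of sweeping a threshold over the probabilities.

```r
# Toy data (hypothetical): 1 = observation is "mock", 0 = any other class
is_mock <- c(1, 1, 0, 1, 0, 0, 1, 0)
p_mock  <- c(0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2, 0.1)

# Sweep the threshold from "call nothing mock" down to "call everything mock"
thresholds <- sort(unique(c(Inf, p_mock)), decreasing = TRUE)
tpr <- sapply(thresholds, function(t) sum(p_mock >= t & is_mock == 1) / sum(is_mock == 1))
fpr <- sapply(thresholds, function(t) sum(p_mock >= t & is_mock == 0) / sum(is_mock == 0))

# Area under the curve by the trapezoidal rule
auc_mock <- sum(diff(fpr) * (head(tpr, -1) + tail(tpr, -1)) / 2)

plot(fpr, tpr, type = "s", xlab = "False positive rate", ylab = "True positive rate")
```

Each distinct probability value contributes one point to the curve, which is why passing predicted classes (only three distinct values) gives such coarse plots.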

# 1 ROC curve, mock vs non mock
roc.mock <- roc(ifelse(predictions$observed=="mock", "mock", "non-mock"), as.numeric(predictions$mock))
plot(roc.mock, col = "gray60")

# others: each curve must use the probability column of its own class
roc.lethal <- roc(ifelse(predictions$observed=="lethal", "lethal", "non-lethal"), as.numeric(predictions$lethal))
roc.resist <- roc(ifelse(predictions$observed=="resist", "resist", "non-resist"), as.numeric(predictions$resist))
lines(roc.lethal, col = "blue")
lines(roc.resist, col = "red")
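
The three calls above can also be written as a loop over the classes, which makes it harder to pair a class with the wrong probability column. The following is a sketch, not a drop-in replacement: the data frame `probs` is a simulated stand-in for the output of `predict(model, test, type = "prob")`, and the names `classes`, `rocs`, `aucs`, and `cols` are made up for this example. It assumes a pROC version that supports the `quiet` argument.

```r
library(pROC)

set.seed(1)
classes  <- c("lethal", "mock", "resist")
observed <- sample(classes, size = 300, replace = TRUE)

# Simulated stand-in for class probabilities: scores are higher for the
# true class, then normalised so that each row sums to 1
scores <- sapply(classes, function(cl) (observed == cl) + runif(300))
probs  <- as.data.frame(scores / rowSums(scores))

# One binary (one-vs-rest) ROC per class, using that class's own column
rocs <- lapply(classes, function(cl) {
  roc(response  = ifelse(observed == cl, "case", "rest"),
      predictor = probs[[cl]],
      levels    = c("rest", "case"),  # controls first, then cases
      direction = "<",                # cases are expected to score higher
      quiet     = TRUE)
})
names(rocs) <- classes
aucs <- sapply(rocs, function(r) as.numeric(auc(r)))

# Overlay the three curves and report each AUC in the legend
cols <- c(lethal = "blue", mock = "gray60", resist = "red")
plot(rocs[["mock"]], col = cols[["mock"]])
lines(rocs[["lethal"]], col = cols[["lethal"]])
lines(rocs[["resist"]], col = cols[["resist"]])
legend("bottomright", lwd = 2, col = cols[classes],
       legend = sprintf("%s (AUC = %.2f)", classes, aucs[classes]))
```

Setting `levels` and `direction` explicitly keeps pROC from guessing which label is the case and which way the predictor points.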

Done. This is the result. Of course, the more observations in your test set, the smoother your curves will be.
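
The number of trees probably matters for smoothness too. Assuming `predict(..., type = "prob")` returns the fraction of trees voting for each class (which matches the 0.1-step values in the output above), a forest with `ntree = 10` can produce at most 11 distinct probabilities, so each curve has at most that many steps. The simulation below illustrates this with made-up vote counts; `vote_fraction` is a hypothetical helper, not part of randomForest.

```r
set.seed(42)
n_obs <- 1000

# Hypothetical stand-in for one probability column: the number of trees
# (out of ntree) voting for a class, divided by ntree
vote_fraction <- function(ntree) rbinom(n_obs, size = ntree, prob = 0.4) / ntree

# Count how many distinct probability values each forest size can yield
n_distinct_10  <- length(unique(vote_fraction(10)))
n_distinct_500 <- length(unique(vote_fraction(500)))

print(c(ntree_10 = n_distinct_10, ntree_500 = n_distinct_500))
```

If the curves from the real model still look stepped, raising `ntree` is a cheap thing to try.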

[Figure: one-vs-rest ROC curves for mock (gray), lethal (blue), and resist (red)]

Adam Price
Author by

Adam Price

Updated on September 13, 2020

Comments

  • Adam Price
    Adam Price over 3 years

    I'm using the R package, randomForest, to create a model that classifies into three groups.

     model = randomForest(formula = condition ~ ., data = train, ntree = 2000,      
                           mtry = bestm, importance = TRUE, proximity = TRUE) 
    
               Type of random forest: classification
                     Number of trees: 2000
                     No. of variables tried at each split: 3
    
               OOB estimate of  error rate: 5.71%
    
               Confusion matrix:
               lethal mock resistant class.error
     lethal        20    1         0  0.04761905
     mock           1   37         0  0.02631579
     resistant      2    0         9  0.18181818
    

    I've tried with several libraries. For instance, with ROCR, you can't do three classifications, only two. Behold:

    pred=prediction(predictions,train$condition)
    
    Error in prediction(predictions, train$condition) : 
      Number of classes is not equal to 2.
      ROCR currently supports only evaluation of binary classification 
      tasks.
    

    data from model$votes is looking like this:

             lethal        mock   resistant
     3   0.04514364 0.952120383 0.002735978
     89  0.32394366 0.147887324 0.528169014
     16  0.02564103 0.973009447 0.001349528
     110 0.55614973 0.433155080 0.010695187
     59  0.06685633 0.903271693 0.029871977
     43  0.13424658 0.865753425 0.000000000
     41  0.82987552 0.033195021 0.136929461
     86  0.32705249 0.468371467 0.204576043
     87  0.37704918 0.341530055 0.281420765
     ........
    

    I can get some pretty ugly ROC plots this way using the pROC package:

    predictions <- as.numeric(predict(model, test, type = 'response'))
    roc.multi <- multiclass.roc(test$condition, predictions, 
                                percent=TRUE)
    rs <- roc.multi[['rocs']]
    plot.roc(rs[[2]])
    sapply(2:length(rs),function(i) lines.roc(rs[[i]],col=i))
    

    Those plots look like this: Figure 1: Ugly ROC curve

    There's no way to smooth those lines though, because they aren't so much of a curve as they are 4 or so points each.

    I need a way to plot a nice smooth ROC curve for this model, but I can't seem to find one. Does anyone know of a good approach? Thanks very much in advance!

    • lebelinoz
      lebelinoz over 6 years
      Could you include the bare Sensitivity + Specificity data? Lots of folks on this site can help you create beautiful charts without understanding the nuances of random forests and ROC
    • Damiano Fantini
      Damiano Fantini over 6 years
      I think you can use either ROCR or pROC, but the best way to show this is to superimpose three lines corresponding to the three possible contrasts you have: mock vs. non-mock; lethal vs. non-lethal; resist vs. non-resist. So, you should convert your problem into 3 binary problems and plot the corresponding ROC curves...
  • user20650
    user20650 over 6 years
    My initial thought on this was also that ROC required binomial outcomes, but after searching the web for multiclass.roc, many links came up; the top search result was stats.stackexchange.com/questions/2151/… . Although maybe this is what you are doing?
  • Damiano Fantini
    Damiano Fantini over 6 years
    @user20650 Good observation. To my understanding, multiclass.roc() only accepts one probability vector (predictor argument) to rank the data. Random Forests assume no linearity in the response, and return n probability vectors (where n is the number of classes). Here, I am showing a way to deal with the problem by overlaying three standard (binary) ROC analyses.
  • Adam Price
    Adam Price over 6 years
    This is very helpful, I learned a lot from your post. Thank you!