Using randomForest package in R, how to get probabilities from classification model?
model$predicted is NOT the same thing as what predict() returns. If you want the probability of the TRUE or FALSE class, you must either run predict() or pass x, y, xtest, ytest, as in randomForest(x, y, xtest=x, ytest=y), where x = out.data[, feature.cols] and y = out.data[, response.col].
model$predicted returns, for each record, the class that had the larger value in model$votes. votes, as @joran pointed out, is the proportion of OOB (out-of-bag) votes from the random forest: a tree's vote counts for a record only when that record was out of bag for that tree. predict(), on the other hand, returns the probability for each class computed as the proportion of votes from all the trees in the forest.
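To make the distinction concrete, here is a minimal sketch on a toy binary problem built from the built-in iris data. The data set, the variable names, and ntree = 200 are assumptions for illustration, not the asker's setup.

```r
library(randomForest)

set.seed(42)
# Assumed toy data: two iris species turned into a TRUE/FALSE factor response
dat <- iris[iris$Species != "setosa", ]
x <- dat[, 1:4]
y <- factor(dat$Species == "virginica")   # levels "FALSE", "TRUE"

model <- randomForest(x, y, ntree = 200)

# OOB vote fractions: a row only counts trees that did not train on it
head(model$votes)

# model$predicted should be the class with the larger OOB vote fraction
# (exact ties, if any, may be resolved differently by the package)
oob_class <- colnames(model$votes)[apply(model$votes, 1, which.max)]
mean(oob_class == as.character(model$predicted))

# predict() with explicit data runs every row down every tree
prob <- predict(model, x, type = "prob")
head(prob)
```

Because predict(model, x, type = "prob") scores the training rows with trees that saw them, those fractions are usually more confident than the OOB fractions in model$votes.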
Calling randomForest(x, y, xtest=x, ytest=y) behaves a little differently from passing a formula or simply calling randomForest(x, y), as in the example given in the question. randomForest(x, y, xtest=x, ytest=y) WILL return the probability for each class. This may sound a little odd, but it is found under model$test$votes, with the predicted class under model$test$predicted, which simply selects the class with the larger value in model$test$votes. Also, when using randomForest(x, y, xtest=x, ytest=y), model$predicted and model$votes have the same definitions as above.
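As a sketch of how this covers the question's use case (class probabilities for a held-out test set without a second predict() pass), the example below reuses the question's names train.rows, test.rows, and feature.cols. The iris-based data, the 70/30 split, and ntree = 200 are assumptions for illustration.

```r
library(randomForest)

set.seed(1)
# Assumed toy data: two iris species with a TRUE/FALSE factor response
dat <- iris[iris$Species != "setosa", ]
dat$is_virginica <- factor(dat$Species == "virginica")

feature.cols <- 1:4
train.rows <- sample(nrow(dat), 70)
test.rows <- setdiff(seq_len(nrow(dat)), train.rows)

model <- randomForest(x = dat[train.rows, feature.cols],
                      y = dat$is_virginica[train.rows],
                      xtest = dat[test.rows, feature.cols],
                      ytest = dat$is_virginica[test.rows],
                      ntree = 200)

# Per-class vote fractions for the test rows, computed while the forest is grown
head(model$test$votes)

# Predicted class for the test rows
head(model$test$predicted)

# Score for the TRUE class, e.g. as input to an ROC calculation
p_true <- model$test$votes[, "TRUE"]
```

The vector p_true can be passed to whichever ROC tool is in use, together with the true labels dat$is_virginica[test.rows].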
Finally, note that if randomForest(x, y, xtest=x, ytest=y) is used, the keep.forest flag must be set to TRUE in order to use the predict() function afterwards:
model = randomForest(x, y, xtest=x, ytest=y, keep.forest=TRUE)
prob = predict(model, x, type="prob")
prob WILL be equivalent to model$test$votes, since both are computed from the same input x.
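The equivalence can be checked with a short sketch. The iris-based data and ntree = 200 are assumptions for illustration; the difference is expected to be zero because both matrices come from running the same rows through the same trees.

```r
library(randomForest)

set.seed(7)
# Assumed toy data: two iris species with a TRUE/FALSE factor response
dat <- iris[iris$Species != "setosa", ]
x <- dat[, 1:4]
y <- factor(dat$Species == "virginica")

# keep.forest = TRUE is needed here; with xtest supplied the forest is
# otherwise not stored and predict() has nothing to work with
model <- randomForest(x, y, xtest = x, ytest = y,
                      keep.forest = TRUE, ntree = 200)

prob <- predict(model, x, type = "prob")

# Largest absolute difference between the two vote-fraction matrices
max_diff <- max(abs(as.numeric(prob) - as.numeric(model$test$votes)))
max_diff
```

Passing the training rows as xtest is only useful for this comparison. For an honest error estimate, use the OOB fractions in model$votes or a genuinely held-out test set.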
Mike Williamson
Updated on July 09, 2022
Comments
-
Mike Williamson almost 2 years
TL;DR: Is there something I can flag in the original randomForest call to avoid having to re-run the predict function to get predicted categorical probabilities, instead of just the likely category?

Details:

I am using the randomForest package. I have a model something like:
model <- randomForest(x=out.data[train.rows, feature.cols], y=out.data[train.rows, response.col], xtest=out.data[test.rows, feature.cols], ytest=out.data[test.rows, response.col], importance= TRUE)
where out.data is a data frame, with feature.cols a mixture of numeric and categorical features, while response.col is a TRUE/FALSE binary variable that I forced into a factor so that the randomForest model will properly treat it as categorical.

All runs well, and the variable model is returned to me properly. However, I cannot seem to find a flag or parameter to pass to the randomForest function so that model is returned to me with the probabilities of TRUE or FALSE. Instead, I get simply predicted values. That is, if I look at model$predicted, I'll see something like:

FALSE FALSE TRUE TRUE FALSE . . .
Instead, I want to see something like:
  FALSE  TRUE
1  0.84  0.16
2  0.66  0.34
3  0.11  0.89
4  0.17  0.83
5  0.92  0.08
.   .     .
I can get the above, but in order to do so, I need to do something like:
tmp <- predict(model, out.data[test.rows, feature.cols], "prob")
[test.rows captures the row numbers for those that were used during the model testing. The details are not shown here, but are simple since the test row IDs are output into model.]

Then everything works fine. The problem is that the model is big and takes a very long time to run, and even the prediction itself takes a while. Since the prediction should be entirely unnecessary (I am simply looking to calculate the ROC curve on the test data set, the data set that should have already been calculated), I was hoping to skip this step. Is there something I can flag in the original randomForest call to avoid having to re-run the predict function?
-
MrFlick over 9 years
The randomForest function can be used for any type of analysis; the question could benefit from a reproducible example that shows exactly what you are running with some sample/representative data. I would think if you just do predict(model, type="prob") it would be faster. Here, you want the prediction from the model you fit, so no need to pass in a newdata= parameter. But since you didn't provide any way to test, it's hard to say if this will solve your problem.
-
joran over 9 years
So you haven't noticed the votes component of the random forest object? There's a pretty clear description of it in the docs.
-
Mike Williamson over 9 years
Thanks, @joran ... I thought that "votes" might simply mean the probability. (E.g., if 300 / 500 trees that an obs. experienced voted "TRUE", then it would give 60% true.) However, that did not seem statistically "tight", in that IID is assumed by proxy. Since proximity and other data are available, I thought maybe more exacting probabilities could be extracted by adjusting the weights in some fashion. I presume this is not done. Thanks for confirmation!
-
joran over 9 years
Your comment makes me think that you should maybe spend some time reading some references on random forests, particularly maybe Breiman's original paper (reference in the package docs). As the docs for votes state, the proportions are for OOB (out-of-bag) votes, so each case is only run down a tree for which it was not in the bootstrap sample. There are some other subtleties to how the OOB error rate is calculated (see oob.times), but what is in votes is fairly rigorous...
-
joran over 9 years
...there are some critiques of OOB error rates, but again I would recommend reading up on the RF literature on that topic.
-
Mike Williamson over 9 years
Hi Oscar, I did provide & have been providing the "test" data set. Apologies I did not originally specify that... I have edited my original post. Thanks for specifying it is under "$test$votes"... that is precisely what I was looking for, although it still seems to be making lots of assumptions, like IID. (There is no covariance test or anything performed, as far as I can tell.) Thanks!
-
Oscar over 9 years
Hi Mike, I'm glad you got it. Don't forget to set keep.forest=TRUE if you want to use the predict() function, just in case you want to pass other test data points. I don't think that there is a covariance test, but I have not looked into it so I'm not sure.