Multinomial classification using neuralnet package

10,549

You are right that the formula interface of neuralnet() does not support '.'.
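Until '.' can be used, the formula can be written out programmatically so the predictors don't have to be typed by hand. The helper `expand_dot_formula` is a hypothetical name (it is not part of neuralnet or base R). This is a minimal sketch that assumes every column not listed as a response is a predictor, and it relies only on base R's `paste()`, `setdiff()` and `as.formula()`.

```r
# Hypothetical helper (not part of neuralnet): spell out what "." would mean,
# assuming every column that is not a response is used as a predictor.
expand_dot_formula <- function(responses, data, env = parent.frame()) {
  predictors <- setdiff(names(data), responses)
  lhs <- paste(responses, collapse = " + ")   # e.g. "setosa + versicolor + virginica"
  rhs <- paste(predictors, collapse = " + ")  # e.g. "Sepal.Length + ... + Petal.Width"
  # Attach the caller's environment so the formula behaves like one typed there
  as.formula(paste(lhs, "~", rhs), env = env)
}

data(iris)
f <- expand_dot_formula("Species", iris)
print(f)
```

With the indicator columns built in `trainData` below, `expand_dot_formula(c("setosa", "versicolor", "virginica"), trainData)` should yield the same multi-response formula that is typed out by hand in the `neuralnet()` call.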

However, the actual problem with your code above is that a factor is not accepted as the target. You have to expand the factor Species into three binary variables first. Ironically, this works best with the function class.ind() from the nnet package (which wouldn't need such a function, since nnet() and multinom() work fine with factors):

library(nnet)       # provides class.ind()
library(neuralnet)
trainData <- cbind(iris[, 1:4], class.ind(iris$Species))
neuralnet(setosa + versicolor + virginica ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, trainData)

This works - at least for me.
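If you would rather not load nnet just for `class.ind()`, the same 0/1 expansion can be done in base R. `one_hot` is a hypothetical helper written for illustration, not a function from either package. This is a minimal sketch that assumes the factor contains no missing values.

```r
# Hypothetical helper: expand a factor into one 0/1 column per level,
# similar in spirit to nnet::class.ind(). Assumes no NA values.
one_hot <- function(f) {
  f <- as.factor(f)
  lev <- levels(f)
  # outer() compares each observation's level code with every level index
  m <- outer(as.integer(f), seq_along(lev), "==") * 1
  colnames(m) <- lev
  m
}

data(iris)
trainData <- cbind(iris[, 1:4], one_hot(iris$Species))
head(trainData, 3)
```

Base R's `model.matrix(~ Species - 1, iris)` gives a similar matrix, but its column names are prefixed with the variable name (for example `Speciessetosa`), so the response names in the neuralnet formula would have to change accordingly.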


Author: Ricardo Magalhães Cruz (MS in Applied Mathematics, BS in Computer Science)

Updated on July 04, 2022

Comments

  • Ricardo Magalhães Cruz (almost 2 years ago)

    This question ought to be really simple, but the documentation isn't helping.

    I am using R. I must use the neuralnet package for a multinomial classification problem.

    All examples are for binomial or linear output. I could do a one-vs-all implementation using binomial output, but I believe I should be able to do this with 3 units in the output layer, where each unit is a binomial output (i.e., the probability that its class is the correct one). No?

    This is what I would do using nnet (which I believe does what I want):

    data(iris)
    library(nnet)
    m1 <- nnet(Species ~ ., iris, size = 3)
    table(predict(m1, iris, type = "class"), iris$Species)
    

    This is what I am trying to do using neuralnet (the formula hack is needed because neuralnet does not seem to support the '.' notation in formulas):

    data(iris)
    library(neuralnet)
    formula <- paste('Species ~', paste(names(iris)[-length(iris)], collapse='+'))
    m2 <- neuralnet(formula, iris, hidden=3, linear.output=FALSE)
    # fails !
    
    • UBod
      It seems that the 'neuralnet' package has been updated and supports this functionality now (since version 1.44.1). So the above code no longer fails, and dots in formulas are now supported as well.
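On the question's idea of three output units: as far as I know, with `linear.output = FALSE` neuralnet applies the logistic function to each output unit on its own (there is no softmax across units), so the three values need not sum to 1. A class label can still be read off by taking the largest output in each row, which is what a `table()` comparison like the nnet one needs. The sketch below assumes the outputs are available as a numeric matrix with one column per class (for example, the `net.result` element returned by neuralnet's `compute()`). `decode_classes` is a hypothetical helper, and the numbers are made up purely to exercise it.

```r
# Hypothetical output matrix: one row per observation, one column per class.
# The values are made up; in practice they would come from the fitted network.
outputs <- matrix(c(0.90, 0.05, 0.10,
                    0.20, 0.70, 0.40,
                    0.10, 0.30, 0.80),
                  nrow = 3, byrow = TRUE,
                  dimnames = list(NULL, c("setosa", "versicolor", "virginica")))

# Hypothetical helper: pick the class whose output unit is largest in each row.
decode_classes <- function(outputs) {
  idx <- max.col(outputs, ties.method = "first")  # deterministic on ties
  factor(colnames(outputs)[idx], levels = colnames(outputs))
}

pred  <- decode_classes(outputs)
truth <- factor(c("setosa", "versicolor", "versicolor"), levels = colnames(outputs))
table(pred, truth)  # confusion table, as in the nnet example
```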