Basic text classification with Weka in Java

16,367

Solution 1

The Bayes classifier gives you a (weighted) probability that the input belongs to each category. These probabilities will almost never be exactly 0 or 1. You can either set a hard cutoff (e.g. 0.5) and decide class membership based on it, or inspect the calculated probabilities and decide based on those (i.e. map the highest to 1 and the others to 0).
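
The cutoff idea can be sketched in plain Java, independent of Weka. The class name `ThresholdDecision`, the `decide` helper, and the 0.5 cutoff are illustrative assumptions, not part of the original code. The input is assumed to be the array returned by `distributionForInstance`, in the same order as the category labels.

```java
import java.util.Arrays;

// Hypothetical helper, not part of Weka: turns a class distribution into a label.
final class ThresholdDecision {

    /**
     * Returns the label with the highest probability, or the fallback label
     * when even the highest probability is below the cutoff.
     */
    static String decide(double[] distribution, String[] labels, double cutoff, String fallback) {
        if (distribution.length == 0 || distribution.length != labels.length) {
            throw new IllegalArgumentException("distribution and labels must be non-empty and of equal length");
        }
        // Find the index of the largest probability (first one wins on ties).
        int best = 0;
        for (int i = 1; i < distribution.length; i++) {
            if (distribution[i] > distribution[best]) {
                best = i;
            }
        }
        return distribution[best] >= cutoff ? labels[best] : fallback;
    }

    public static void main(String[] args) {
        String[] labels = {"computer", "sport", "unknown"};
        // The two distributions printed in the question.
        double[] java = {0.5769230769230769, 0.2884615384615385, 0.1346153846153846};
        double[] gibberish = {0.42857142857142855, 0.42857142857142855, 0.14285714285714285};

        System.out.println(Arrays.toString(java) + " -> " + decide(java, labels, 0.5, "unknown"));
        System.out.println(Arrays.toString(gibberish) + " -> " + decide(gibberish, labels, 0.5, "unknown"));
    }
}
```

With a cutoff of 0.5, the first distribution maps to computer and the second falls back to unknown because no class reaches the cutoff. The right cutoff depends on the data and has to be tuned.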

Solution 2

I thought I would just mention that you can do most text classification work of this kind with no coding by downloading and using LightSIDE from http://lightsidelabs.com. This open source Java package includes WEKA and is available for both Windows and Mac. It can process most WEKA-friendly data sets with great flexibility, letting you iterate through various models, settings and parameters, with good support for snapshots and for saving your data, models and classification results at any point until you have built a model you are happy with. This product proved itself in the ASAP competition on Kaggle.com last year, and is getting a lot of traction. Of course there are always reasons people want or need to "roll their own", but if you are programming WEKA solutions, knowing about and using LightSIDE could be very handy, perhaps even just as a check.

Author: joxxe

Updated on June 12, 2022

Comments

  • joxxe
    joxxe almost 2 years

    I'm trying to build a text classifier in Java with Weka. I have read some tutorials, and I'm trying to build my own classifier.

    I have the following categories:

        computer,sport,unknown 
    

    and the following training data

     cs -> computer
     java -> computer
     soccer -> sport
     snowboard -> sport
    

    So for example, if a user wants to classify the word java, it should return the category computer (no doubt, java only exists in that category!).

    It does compile, but generates strange output.

    The output is:

          ====== RESULT ======  CLASSIFIED AS:  [0.5769230769230769, 0.2884615384615385, 0.1346153846153846]
          ====== RESULT ======  CLASSIFIED AS:  [0.42857142857142855, 0.42857142857142855, 0.14285714285714285]
    

    But the first text to classify is java, and it occurs only in the category computer, so it should be

          [1.0 0.0 0.0] 
    

    and the other one shouldn't be found at all, so it should be classified as unknown

          [0.0 0.0 1.0].
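
    A note on why the expected [1.0 0.0 0.0] does not appear: both printed vectors are consistent with add-one (Laplace) smoothing, which keeps every probability above zero. The sketch below recomputes them by hand. It is not Weka code; the class `SmoothingCheck`, its method names, and in particular the vocabulary size of 5 (four words plus one, inferred from the printed numbers) are assumptions.

    ```java
    import java.util.Arrays;

    // Hypothetical hand calculation, not Weka's implementation.
    final class SmoothingCheck {

        /** Smoothed class priors: (documents in class + 1), normalized to sum to 1. */
        static double[] priors(double[] docsPerClass) {
            double[] p = new double[docsPerClass.length];
            double sum = 0.0;
            for (int c = 0; c < p.length; c++) {
                p[c] = docsPerClass[c] + 1.0;
                sum += p[c];
            }
            for (int c = 0; c < p.length; c++) {
                p[c] /= sum;
            }
            return p;
        }

        /** Posterior for a message that consists of exactly one known word. */
        static double[] posterior(double[] docsPerClass, double[] wordCountInClass,
                                  double[] totalWordsInClass, double vocabularySize) {
            double[] p = new double[docsPerClass.length];
            double sum = 0.0;
            for (int c = 0; c < p.length; c++) {
                double prior = docsPerClass[c] + 1.0; // unnormalized, normalized below
                double likelihood = (wordCountInClass[c] + 1.0) / (totalWordsInClass[c] + vocabularySize);
                p[c] = prior * likelihood;
                sum += p[c];
            }
            for (int c = 0; c < p.length; c++) {
                p[c] /= sum;
            }
            return p;
        }

        public static void main(String[] args) {
            // Classes in order: computer, sport, unknown.
            double[] docs = {2, 2, 0};       // training documents per class
            double[] javaCount = {1, 0, 0};  // occurrences of "java" per class
            double[] totalWords = {2, 2, 0}; // total words seen per class
            System.out.println(Arrays.toString(posterior(docs, javaCount, totalWords, 5)));
            System.out.println(Arrays.toString(priors(docs)));
        }
    }
    ```

    For java the unnormalized scores are 3 * 2/7, 3 * 1/7 and 1 * 1/5, which normalize to 30/52, 15/52 and 7/52. For a message with no known word only the smoothed priors remain: 3/7, 3/7 and 1/7.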
    

    Here is the code:

        import java.io.FileNotFoundException;
        import java.io.Serializable;
        import java.util.Arrays;
    
        import weka.classifiers.Classifier;
        import weka.classifiers.bayes.NaiveBayesMultinomialUpdateable;
        import weka.core.Attribute;
        import weka.core.FastVector;
        import weka.core.Instance;
        import weka.core.Instances;
        import weka.filters.Filter;
        import weka.filters.unsupervised.attribute.StringToWordVector;
    
        public class TextClassifier implements Serializable {
    
            private static final long serialVersionUID = -1397598966481635120L;
            public static void main(String[] args) {
                try {
                    TextClassifier cl = new TextClassifier(new NaiveBayesMultinomialUpdateable());
                    cl.addCategory("computer");
                    cl.addCategory("sport");
                    cl.addCategory("unknown");
                    cl.setupAfterCategorysAdded();
    
                    // Add the training examples.
                    cl.addData("cs", "computer");
                    cl.addData("java", "computer");
                    cl.addData("soccer", "sport");
                    cl.addData("snowboard", "sport");
    
                    double[] result = cl.classifyMessage("java");
                    System.out.println("====== RESULT ====== \tCLASSIFIED AS:\t" + Arrays.toString(result));
    
                    result = cl.classifyMessage("asdasdasd");
                    System.out.println("====== RESULT ======\tCLASSIFIED AS:\t" + Arrays.toString(result));
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
            private Instances trainingData;
            private StringToWordVector filter;
            private Classifier classifier;
            private boolean upToDate;
            private FastVector classValues;
            private FastVector attributes;
            private boolean setup;
    
            private Instances filteredData;
    
            public TextClassifier(Classifier classifier) throws FileNotFoundException {
                this(classifier, 10);
            }
    
            public TextClassifier(Classifier classifier, int startSize) throws FileNotFoundException {
                this.filter = new StringToWordVector();
                this.classifier = classifier;
                // Create vector of attributes.
                this.attributes = new FastVector(2);
                // Add attribute for holding texts.
                this.attributes.addElement(new Attribute("text", (FastVector) null));
                // Add class attribute.
                this.classValues = new FastVector(startSize);
                this.setup = false;
    
            }
    
            public void addCategory(String category) {
                category = category.toLowerCase();
                // if required, double the capacity.
                int capacity = classValues.capacity();
                if (classValues.size() > (capacity - 5)) {
                    classValues.setCapacity(capacity * 2);
                }
                classValues.addElement(category);
            }
    
            public void addData(String message, String classValue) throws IllegalStateException {
                if (!setup) {
                    throw new IllegalStateException("Must use setup first");
                }
                message = message.toLowerCase();
                classValue = classValue.toLowerCase();
                // Make message into instance.
                Instance instance = makeInstance(message, trainingData);
                // Set class value for instance.
                instance.setClassValue(classValue);
                // Add instance to training data.
                trainingData.add(instance);
                upToDate = false;
            }
    
            /**
             * Check whether classifier and filter are up to date. Build if necessary.
             * @throws Exception
             */
            private void buildIfNeeded() throws Exception {
                if (!upToDate) {
                    // Initialize filter and tell it about the input format.
                    filter.setInputFormat(trainingData);
                    // Generate word counts from the training data.
                    filteredData = Filter.useFilter(trainingData, filter);
                    // Rebuild classifier.
                    classifier.buildClassifier(filteredData);
                    upToDate = true;
                }
            }
    
            public double[] classifyMessage(String message) throws Exception {
                message = message.toLowerCase();
                if (!setup) {
                    throw new Exception("Must use setup first");
                }
                // Check whether classifier has been built.
                if (trainingData.numInstances() == 0) {
                    throw new Exception("No classifier available.");
                }
                buildIfNeeded();
                Instances testset = trainingData.stringFreeStructure();
                Instance testInstance = makeInstance(message, testset);
    
                // Filter instance.
                filter.input(testInstance);
                Instance filteredInstance = filter.output();
                return classifier.distributionForInstance(filteredInstance);
    
            }
    
            private Instance makeInstance(String text, Instances data) {
                // Create instance of length two.
                Instance instance = new Instance(2);
                // Set value for message attribute
                Attribute messageAtt = data.attribute("text");
                instance.setValue(messageAtt, messageAtt.addStringValue(text));
                // Give instance access to attribute information from the dataset.
                instance.setDataset(data);
                return instance;
            }
    
            public void setupAfterCategorysAdded() {
                attributes.addElement(new Attribute("class", classValues));
                // Create dataset with initial capacity of 100, and set index of class.
                trainingData = new Instances("MessageClassificationProblem", attributes, 100);
                trainingData.setClassIndex(trainingData.numAttributes() - 1);
                setup = true;
            }
    
        }
    

    Btw, found a good page:

    http://www.hakank.org/weka/TextClassifierApplet3.html

  • joxxe
    joxxe about 12 years
    Yes, I know that. But in this example, when I try to classify result = cl.classifyMessage("asdasdasd"); the result should be unknown, but it is not :/ I can see now that it probably won't work, because I don't have any documents at all for that category... Is there a smart way to add an "unknown" category or similar? Also, for the word java, I thought it would be weighted more towards the computer category, because it is not even mentioned in the other categories.
  • Lars Kotthoff
    Lars Kotthoff about 12 years
    You would have to manually add things to an "unknown" category if the probability for the membership of each of your classes is too low.
  • joxxe
    joxxe about 12 years
    But the probabilities always sum to 1.0. So if I try to classify a word that does not exist in any document, the probabilities for all categories together still sum to 1.0.
  • Lars Kotthoff
    Lars Kotthoff about 12 years
    They can still all be below 0.5, for example. You'll have to make the decision which probability to choose to mean "belongs to this category" anyway.
  • joxxe
    joxxe about 12 years
    Lars: That's true. That could be one solution. Another could be to check the training documents for the word, and if it does not exist in any of them, skip the actual classifier and just return "unknown".
  • ronalchn
    ronalchn over 11 years
    this answer is difficult to understand - when you post code, you should try to explain it
  • Admin
    Admin over 11 years
    Well, if joxxe tries this out in his code, he will understand what I mean. Anyway, the distributionForInstance() method returns the probabilities for each target class, and the classifyInstance() method picks the class with the highest probability and returns its index. So I guess joxxe needs the second method.
  • dzieciou
    dzieciou about 11 years
    What if this is a non-binary problem with more than two categories? Does it mean each category has its own value and I need to find the one with the shortest distance to the output of the classifier?
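
The fallback suggested in the comments (return "unknown" when no word of the message was seen during training) can be sketched without touching the classifier. `VocabularyGate`, `remember`, and `hasKnownWord` are hypothetical names, and splitting on whitespace is a simplifying assumption; Weka's `StringToWordVector` uses its own tokenizer, so the two vocabularies may not match exactly.

```java
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;

// Hypothetical helper, not part of Weka: tracks which words appeared in training messages.
final class VocabularyGate {

    private final Set<String> vocabulary = new HashSet<>();

    /** Records every word of a training message. */
    void remember(String message) {
        for (String token : tokenize(message)) {
            vocabulary.add(token);
        }
    }

    /** Returns true if at least one word of the message was seen during training. */
    boolean hasKnownWord(String message) {
        for (String token : tokenize(message)) {
            if (vocabulary.contains(token)) {
                return true;
            }
        }
        return false;
    }

    // Lower-cases and splits on whitespace; an empty message yields no tokens.
    private static String[] tokenize(String message) {
        String trimmed = message.trim().toLowerCase(Locale.ROOT);
        return trimmed.isEmpty() ? new String[0] : trimmed.split("\\s+");
    }

    public static void main(String[] args) {
        VocabularyGate gate = new VocabularyGate();
        for (String word : new String[] {"cs", "java", "soccer", "snowboard"}) {
            gate.remember(word);
        }
        System.out.println("java known: " + gate.hasKnownWord("java"));
        System.out.println("asdasdasd known: " + gate.hasKnownWord("asdasdasd"));
    }
}
```

One possible wiring is to call `remember` from `addData` and to return "unknown" from the caller whenever `hasKnownWord` is false, only running `classifyMessage` otherwise. With this approach the "unknown" category no longer needs to be a trained class.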