Trie-based dictionary

I am working on a spelling game, so I need fast dictionary lookup, and I also need to find the words that can be built from a random set of letters. I’m using the official Scrabble SOWPODS word list.

EDIT: Here is the code incorporating most of the suggestions in the thread, many thanks to everyone. If you want to read the backstory of the code, just go to ENDEDIT and read on :slight_smile:


package com.clarke.agnes.verbal;

import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.*;

/**
 * Parses a dictionary into a trie.
 *
 * <p>The word list is read one word per line ({@code \n} or {@code \r\n}
 * terminated). Only words of at most {@link #MAX_WORD_LENGTH} letters are kept;
 * words of exactly that length are additionally remembered as candidates for
 * {@link #scrambleWord()}.
 *
 * <p>The trie stores 26 child slots per node, selected by {@code letter % 26}.
 * That mapping is collision-free for a word list written entirely in lower-case
 * {@code a-z} or entirely in upper-case {@code A-Z}; a word list mixing letters
 * that share a slot is rejected while loading.
 *
 * Created by agslinda on 7/2/14.
 */
public class TrieDictionary {

    private static final int MAX_WORD_LENGTH = 9;
    /** Word list read by the no-argument constructor. */
    private static final String DEFAULT_DICTIONARY_PATH = "...sowpods.txt";

    ArrayTrie arrayTrie = new ArrayTrie();
    private ArrayList<String> longestWords = new ArrayList<String>(45000);
    private ArrayList<String> words = new ArrayList<String>();
    private final Random random = new Random();

    /**
     * Loads the default word list.
     *
     * @throws IOException if the word list cannot be opened or read
     */
    public TrieDictionary() throws IOException {
        this(new File(DEFAULT_DICTIONARY_PATH));
    }

    /**
     * Loads the given word list.
     *
     * @param dictionaryFile text file containing one word per line
     * @throws IOException if the file cannot be opened or read
     * @throws IllegalArgumentException if two different letters of the word list
     *         map to the same trie slot (see the class description)
     */
    public TrieDictionary(File dictionaryFile) throws IOException {
        FileReader reader = new FileReader(dictionaryFile);
        try {
            StringBuilder line = new StringBuilder();
            int character = reader.read();
            while (character >= 0) {
                if (character == '\n') {
                    addDictionaryWord(line.toString());
                    line.setLength(0);
                } else if (character != '\r') {
                    // '\r' is dropped so that Windows line endings are accepted.
                    line.append((char) character);
                }
                character = reader.read();
            }
            // The last word is not necessarily followed by a line break.
            addDictionaryWord(line.toString());
        } finally {
            // Close the reader even when reading or inserting fails.
            reader.close();
        }
    }

    /** Registers one line of the word list; blank lines and over-long words are ignored. */
    private void addDictionaryWord(String word) {
        if (word.isEmpty() || word.length() > MAX_WORD_LENGTH) {
            return;
        }
        arrayTrie.addWord(word);
        words.add(word);
        if (word.length() == MAX_WORD_LENGTH) {
            longestWords.add(word);
        }
    }

    /**
     * Picks a random word of {@link #MAX_WORD_LENGTH} letters and returns its
     * letters in random order.
     *
     * @return the shuffled letters of the chosen word
     * @throws IllegalArgumentException if the dictionary holds no word of
     *         {@link #MAX_WORD_LENGTH} letters (raised by {@link Random#nextInt(int)})
     */
    public String scrambleWord() {
        String word = longestWords.get(random.nextInt(longestWords.size()));
        List<Character> letters = new ArrayList<Character>(word.length());
        for (char c : word.toCharArray()) {
            letters.add(c);
        }
        Collections.shuffle(letters, random);
        StringBuilder sb = new StringBuilder(word.length());
        for (Character letter : letters) {
            sb.append(letter.charValue());
        }
        return sb.toString();
    }

    /**
     * Tells whether a word is in the dictionary.
     *
     * @param candidateWord word to look up; {@code null} and {@code ""} are never found
     * @return {@code true} if the exact word was loaded from the word list
     */
    public boolean exists(String candidateWord) {
        return candidateWord != null && arrayTrie.wordExists(candidateWord);
    }

    /**
     * Finds every dictionary word that can be spelled with some or all of the
     * given letters, using each letter at most as often as it occurs.
     *
     * @param scrambledWord the available letters; {@code null} yields no words
     * @return the matching words, without duplicates, in natural {@code String} order
     */
    public List<String> unscramble(String scrambledWord) {
        if (scrambledWord == null) {
            return new ArrayList<String>();
        }
        char[] letters = scrambledWord.toCharArray();
        // Sorting puts equal letters next to each other so the search can skip repeats.
        Arrays.sort(letters);
        Set<String> results = new TreeSet<String>();
        collectWords(arrayTrie.root, letters, new boolean[letters.length],
                new StringBuilder(letters.length), results);
        return new ArrayList<String>(results);
    }

    /**
     * Depth-first walk of the trie that extends {@code prefix} with every unused
     * letter and records each complete word it reaches.
     *
     * @param node    trie node reached by {@code prefix}
     * @param letters the available letters, sorted
     * @param used    marks the letters already consumed by {@code prefix}
     * @param prefix  letters chosen so far; restored before returning
     * @param results receives the words found
     */
    private static void collectWords(Node node, char[] letters, boolean[] used,
                                     StringBuilder prefix, Set<String> results) {
        for (int i = 0; i < letters.length; i++) {
            if (used[i]) {
                continue;
            }
            // An equal, still unused neighbour means this letter was already tried at this depth.
            if (i > 0 && letters[i] == letters[i - 1] && !used[i - 1]) {
                continue;
            }
            CharacterNode child = node.get(letters[i]);
            if (child == null) {
                continue;
            }
            used[i] = true;
            prefix.append(letters[i]);
            if (child.completeWord) {
                results.add(prefix.toString());
            }
            collectWords(child, letters, used, prefix, results);
            prefix.setLength(prefix.length() - 1);
            used[i] = false;
        }
    }


    /**
     * Small benchmark: looks up every loaded word plus two altered variants of it,
     * then scrambles a random longest word and lists the words hidden in it.
     */
    public static void main(String[] args) throws IOException {
        TrieDictionary dict = new TrieDictionary();

        // x only accumulates the lookup results; it is printed at the end.
        int x = 0, spell, misspell;
        long start = System.currentTimeMillis();
        spell = 0; misspell = 0;
        for (int i = 0; i < dict.words.size(); i += 1) {
            String word = dict.words.get(i);
            x += dict.exists(word) ? spell++ : misspell++;
            x += dict.exists('n' + word) ? spell++ : misspell++;
            x += dict.exists(word + 'n') ? spell++ : misspell++;
        }
        System.out.println("Trie took " + (System.currentTimeMillis() - start) + " right:wrong " + spell + ":" + misspell);
        System.out.println("Trie nodes: " + Node.instanceCount + " Arrays: " + Node.arrayCount);
        System.out.println("Words: " + dict.words.size() + " LongestWords: " + dict.longestWords.size());
        String s = dict.scrambleWord();
        start = System.currentTimeMillis();
        List<String> unscramble = dict.unscramble(s);
        System.out.println("Unscramble took " + (System.currentTimeMillis() - start));
        System.out.println(x + "\r" + "Word of the day: " + s + " has " + unscramble.size() + " solutions");
        int counter = 0;
        for (String s1 : unscramble) {
            if (s1.length() == MAX_WORD_LENGTH) {
                System.out.print("###");
            }
            if (s1.length() >= MAX_WORD_LENGTH / 2) {
                System.out.print(s1 + "   ");
                if (counter++ == 5) {
                    counter = 0;
                    System.out.println();
                }
            }
        }
    }

    /** Prefix tree whose nodes keep their children in fixed-size arrays. */
    private static class ArrayTrie {
        private final Node root = new Node();

        /** Inserts a word; an empty word has no node to mark and is ignored. */
        private void addWord(String word) {
            Node n = root;
            CharacterNode last = null;
            for (int i = 0; i < word.length(); i++) {
                last = n.put(word.charAt(i));
                n = last;
            }
            if (last != null) {
                last.completeWord = true;
            }
        }

        /** Follows the word letter by letter; only a node marked as a word end counts. */
        private boolean wordExists(String candidateWord) {
            Node n = root;
            CharacterNode last = null;
            for (int i = 0; i < candidateWord.length(); i++) {
                last = n.get(candidateWord.charAt(i));
                if (last == null) {
                    return false;
                }
                n = last;
            }
            // last stays null for the empty string, which is never a word.
            return last != null && last.completeWord;
        }
    }

    /** Trie node owning the child array; also used directly as the letter-less root. */
    private static class Node {
        /** Number of letter nodes created so far, over all tries (diagnostics for main). */
        protected static int instanceCount = 0;
        /** Number of child arrays allocated so far, over all tries (diagnostics for main). */
        protected static int arrayCount = 0;
        private static final int ALPHABET_LENGTH = 26;
        /** Allocated on the first insert so that leaves carry no array. */
        protected CharacterNode[] children = null;

        /**
         * Returns the child for the given letter, creating it if necessary.
         *
         * @throws IllegalArgumentException if the slot is already taken by a different letter
         */
        public CharacterNode put(char c) {
            if (children == null) {
                arrayCount++;
                children = new CharacterNode[ALPHABET_LENGTH];
            }
            int slot = c % ALPHABET_LENGTH;
            CharacterNode n = children[slot];
            if (n == null) {
                n = new CharacterNode(c);
                children[slot] = n;
            } else if (n.c != c) {
                // Merging two letters would corrupt every word below this node.
                throw new IllegalArgumentException("Letter '" + c + "' shares a trie slot with '"
                        + n.c + "'; the word list must use a single 26-letter alphabet");
            }
            return n;
        }

        /** Returns the child for exactly this letter, or {@code null} if there is none. */
        public CharacterNode get(char c) {
            if (children == null) {
                return null;
            }
            CharacterNode n = children[c % ALPHABET_LENGTH];
            // Different letters can share a slot, so the stored letter must match too.
            return (n != null && n.c == c) ? n : null;
        }

    }

    /** Trie node standing for one letter of a word. */
    private static class CharacterNode extends Node {
        private final char c;
        /** True when the path from the root to this node spells a dictionary word. */
        private boolean completeWord = false;

        public CharacterNode(char c) {
            this.c = c;
            instanceCount++;
        }
    }
}

ENDEDIT

I have three spell check methods implemented for comparison:

  1. Naive list iteration (at least I stop iterating if we go past where the misspelled word would have been in the dictionary)
  2. Binary search (quick and dirty; I just added it to show how fast Trie is)
  3. Trie search (prefix tree search). This is supposed to be the fastest.

Anyway, binary won by a couple of millis… :frowning:

Trie took 46 right:wrong 7263:14448
Binary took 33 right:wrong 7263:14448
List took 98224 right:wrong 7263:14448

I can get Trie to win by a margin of 2-3 millis if I use HashMaps and not ArrayLists for the Trie child node lists. But I think that will take too much memory on Android.


package com.clarke.agnes.verbal;

import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Parses a dictionary into a trie.
 *
 * <p>The word list is read one word per line ({@code \n} or {@code \r\n}
 * terminated). Besides the trie, the words are kept in a list in file order so
 * that the trie lookup can be compared with a linear scan and a binary search;
 * both list-based lookups require the file to be sorted in natural
 * {@code String} order.
 *
 * Created by agslinda on 7/2/14.
 */
public class TrieDictionary {

    /** Word list read by the no-argument constructor. */
    private static final String DEFAULT_DICTIONARY_PATH = "...sowpods.txt";

    Trie trie = new Trie();
    private ArrayList<String> words = new ArrayList<String>(280000);

    /**
     * Loads the default word list.
     *
     * @throws IOException if the word list cannot be opened or read
     */
    public TrieDictionary() throws IOException {
        this(new File(DEFAULT_DICTIONARY_PATH));
    }

    /**
     * Loads the given word list.
     *
     * @param dictionaryFile text file containing one word per line
     * @throws IOException if the file cannot be opened or read
     */
    public TrieDictionary(File dictionaryFile) throws IOException {
        FileReader reader = new FileReader(dictionaryFile);
        try {
            StringBuilder line = new StringBuilder();
            int character = reader.read();
            while (character >= 0) {
                if (character == '\n') {
                    addDictionaryWord(line.toString());
                    line.setLength(0);
                } else if (character != '\r') {
                    // '\r' is dropped so that Windows line endings are accepted.
                    line.append((char) character);
                }
                character = reader.read();
            }
            // The last word is not necessarily followed by a line break.
            addDictionaryWord(line.toString());
        } finally {
            // Close the reader even when reading fails.
            reader.close();
        }
    }

    /** Registers one line of the word list; blank lines are ignored. */
    private void addDictionaryWord(String word) {
        if (word.isEmpty()) {
            return;
        }
        trie.addWord(word);
        words.add(word);
    }

    /**
     * Trie lookup.
     *
     * @param candidateWord word to look up; {@code null} and {@code ""} are never found
     * @return {@code true} if the exact word was loaded from the word list
     */
    public boolean exists(String candidateWord) {
        return candidateWord != null && trie.wordExists(candidateWord);
    }

    /**
     * Linear scan over the word list, kept as the slow baseline. Stops as soon as
     * the scan passes the position the word would occupy in a sorted list.
     *
     * @param candidateWord word to look up; {@code null} is never found
     * @return {@code true} if the word is found before the scan stops
     */
    public boolean existsOld(String candidateWord) {
        if (candidateWord == null) {
            return false;
        }
        for (String word : words) {
            int order = candidateWord.compareTo(word);
            if (order == 0) {
                return true;
            }
            if (order < 0) {
                return false;
            }
        }
        return false;
    }

    /**
     * Binary search over the word list, which must be sorted in natural
     * {@code String} order.
     *
     * @param candidateWord word to look up; {@code null} is never found
     * @return {@code true} if the search locates the word
     */
    public boolean existsBinary(String candidateWord) {
        return candidateWord != null && Collections.binarySearch(words, candidateWord) >= 0;
    }


    /**
     * Prints a few sanity checks, then times the three lookup strategies on every
     * 37th word of the list plus two altered variants of it.
     */
    public static void main(String[] args) throws IOException {
        final TrieDictionary dict = new TrieDictionary();
        System.out.println("able " + dict.exists("able") + dict.existsBinary("able") + dict.existsOld("able"));
        System.out.println("zzz " + dict.exists("zzz") + dict.existsBinary("zzz") + dict.existsOld("zzz"));
        System.out.println("hjkglfy " + dict.exists("hjkglfy") + dict.existsBinary("hjkglfy") + dict.existsOld("hjkglfy"));
        System.out.println("true " + dict.exists("true") + dict.existsBinary("true") + dict.existsOld("true"));
        System.out.println("abstemious " + dict.exists("abstemious") + dict.existsBinary("abstemious") + dict.existsOld("abstemious"));

        benchmark("Trie", dict.words, new Lookup() {
            public boolean contains(String word) {
                return dict.exists(word);
            }
        });
        benchmark("Binary", dict.words, new Lookup() {
            public boolean contains(String word) {
                return dict.existsBinary(word);
            }
        });
        benchmark("List", dict.words, new Lookup() {
            public boolean contains(String word) {
                return dict.existsOld(word);
            }
        });
    }

    /** One of the lookup strategies under comparison. */
    private interface Lookup {
        boolean contains(String word);
    }

    /**
     * Times one lookup strategy and prints the elapsed milliseconds together with
     * the number of hits and misses.
     */
    private static void benchmark(String label, List<String> words, Lookup lookup) {
        int spell = 0, misspell = 0;
        long start = System.currentTimeMillis();
        for (int i = 0; i < words.size(); i += 37) {
            String word = words.get(i);
            String[] candidates = {word, 'n' + word, word + 'n'};
            for (String candidate : candidates) {
                if (lookup.contains(candidate)) {
                    spell++;
                } else {
                    misspell++;
                }
            }
        }
        System.out.println(label + " took " + (System.currentTimeMillis() - start) + " right:wrong " + spell + ":" + misspell);
    }

    /** Prefix tree whose nodes keep their children in small lists. */
    private static class Trie {
        private final Node root = new Node();

        /** Inserts a word; an empty word has no node to mark and is ignored. */
        private void addWord(String word) {
            Node n = root;
            CharacterNode last = null;
            for (int i = 0; i < word.length(); i++) {
                last = n.put(word.charAt(i));
                n = last;
            }
            if (last != null) {
                last.completeWord = true;
            }
        }

        /** Follows the word letter by letter; only a node marked as a word end counts. */
        private boolean wordExists(String candidateWord) {
            Node n = root;
            CharacterNode last = null;
            for (int i = 0; i < candidateWord.length(); i++) {
                last = n.get(candidateWord.charAt(i));
                if (last == null) {
                    return false;
                }
                n = last;
            }
            // last stays null for the empty string, which is never a word.
            return last != null && last.completeWord;
        }
    }

    /** Trie node owning the child list; also used directly as the letter-less root. */
    private static class Node {
        /** Shared placeholder so that childless nodes do not allocate a list. */
        protected static final List<CharacterNode> NO_CHILDREN = Collections.emptyList();
        protected List<CharacterNode> children = NO_CHILDREN;

        /** Returns the child for the given letter, creating it if necessary. */
        public CharacterNode put(char c) {
            CharacterNode existing = get(c);
            if (existing != null) {
                return existing;
            }
            // Identity check: only the shared placeholder must be replaced before adding.
            if (children == NO_CHILDREN) {
                children = new ArrayList<CharacterNode>();
            }
            CharacterNode created = new CharacterNode(c);
            children.add(created);
            return created;
        }

        /** Returns the child for the given letter, or {@code null} if there is none. */
        public CharacterNode get(char c) {
            for (CharacterNode n : children) {
                if (n.value == c) {
                    return n;
                }
            }
            return null;
        }
    }

    /** Trie node standing for one letter of a word. */
    private static class CharacterNode extends Node {
        /** True when the path from the root to this node spells a dictionary word. */
        private boolean completeWord = false;
        private final char value;

        private CharacterNode(char value) {
            this.value = value;
        }

    }

}