Trie-based dictionary

I am working on a spelling game, so I need fast dictionary lookup and also to find words from random sets. I’m using the official Scrabble SOWPODS word list.

EDIT: Here is the code incorporating most of the suggestions in the thread; many thanks to everyone. If you want to read the backstory of the code, skip to ENDEDIT and read on :)


package com.clarke.agnes.verbal;

import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.*;

/**
 * Parses a dictionary into a trie.
 * Created by agslinda on 7/2/14.
 */
public class TrieDictionary {

    private static final int MAX_WORD_LENGTH = 9;
    ArrayTrie arrayTrie = new ArrayTrie();
    private ArrayList<String> longestWords = new ArrayList<String>(45000);
    private ArrayList<String> words = new ArrayList<String>();

    public TrieDictionary() throws IOException {
        StringBuilder sb = new StringBuilder();
        FileReader reader = new FileReader(new File("...sowpods.txt"));
        int character = reader.read();
        while (character >= 0) {
            //noinspection StatementWithEmptyBody
            if (character == '\r') {
                //do nothing
            } else if (character == '\n') {
                String word = sb.toString();
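                // skip blank lines: an empty word would end on the root Node, which is not a CharacterNode
                if (sb.length() > 0 && sb.length() <= MAX_WORD_LENGTH) {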
                    arrayTrie.addWord(word);
                    words.add(word);
                }
                if (sb.length() == MAX_WORD_LENGTH) {
                    longestWords.add(word);
                }
                sb.delete(0, sb.length());
            } else {
                sb.append((char)character);
            }
            character = reader.read();
        }
        reader.close();
    }

    private final Random random = new Random();
    public String scrambleWord() {
        TreeMap<Double, Character> map = new TreeMap<Double, Character>();
        String word = longestWords.get(random.nextInt(longestWords.size()));
        while (map.size() != word.length()) {
            for (char c : word.toCharArray()) {
                map.put(random.nextDouble(), c);
            }
            if (map.size() != word.length()) {
                map.clear();
            }
        }
        StringBuilder sb = new StringBuilder();
        for (Character character : map.values()) {
            sb.append(character);
        }
        return sb.toString();
    }

    public boolean exists(String candidateWord) {
        return arrayTrie.wordExists(candidateWord);
    }

    public List<String> unscramble(String scrambledWord) {
        TreeMap<String, Object> results = new TreeMap<String, Object>();
        ArrayList<Character> scrambled = new ArrayList<Character>();
        for (char c : scrambledWord.toCharArray()) {
            scrambled.add(c);
        }
        for (char character : scrambledWord.toCharArray()) {
            CharacterNode node = arrayTrie.get(character);
            if (node != null) {
                int index = scrambled.indexOf(character);
                scrambled.remove(Character.valueOf(character));
                node.unscramble(results, scrambled);
                scrambled.add(index, character);
            }
        }
        return new ArrayList<String>(results.keySet());
    }


    public static void main(String[] args) throws IOException {
        TrieDictionary dict = new TrieDictionary();

        int x = 0, spell, misspell;
        long start = System.currentTimeMillis();
        spell = 0; misspell = 0;
        for (int i = 0; i < dict.words.size(); i += 1) {
            String word = dict.words.get(i);
            x += dict.exists(word) ? spell++ : misspell++;
            x += dict.exists('n' + word) ? spell++ : misspell++;
            x += dict.exists(word + 'n') ? spell++ : misspell++;
        }
        System.out.println("Trie took " + (System.currentTimeMillis() - start) + " right:wrong " + spell + ":" + misspell);
        System.out.println("Trie nodes: " + Node.instanceCount + " Arrays: " + Node.arrayCount);
        System.out.println("Words: " + dict.words.size() + " LongestWords: " + dict.longestWords.size());
        String s = dict.scrambleWord();
        start = System.currentTimeMillis();
        List<String> unscramble = dict.unscramble(s);
        System.out.println("Unscramble took " + (System.currentTimeMillis() - start));
        System.out.println(x + "\r" + "Word of the day: " + s + " has " + unscramble.size() + " solutions");
        int counter = 0;
        for (String s1 : unscramble) {
            if (s1.length() == MAX_WORD_LENGTH) {
                System.out.print("###");
            }
            if (s1.length() >= MAX_WORD_LENGTH / 2) {
                System.out.print(s1 + "   ");
                if (counter++ == 5) {
                    counter = 0;
                    System.out.println();
                }
            }
        }
    }

    private static class ArrayTrie {
        Node root = new Node();
        private void addWord(String word) {
            Node n = root;
            for (char c : word.toCharArray()) {
                n = n.put(c);
            }
            ((CharacterNode)n).completeWord = true;
        }

        private boolean wordExists(String candidateWord) {
            Node n = root;
            char[] chars = candidateWord.toCharArray();
            for (int i = 0; i < chars.length; i++) {
                n = n.get(chars[i]);
                if (n == null) {
                    break;
                }
            }
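            // instanceof is false for null and for the root (empty candidate word), avoiding a ClassCastException
            return (n instanceof CharacterNode) && ((CharacterNode) n).completeWord;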
        }

        public CharacterNode get(char c) {
            return root.get(c);
        }
    }

    private static class Node {
        protected static int instanceCount = 0;
        protected static int arrayCount = 0;
        protected Node parent = null;
        protected char c;
        private static final int ALPHABET_LENGTH = 26;
        protected CharacterNode[] children = null;

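        public CharacterNode put(char c) {
            CharacterNode n = get(c);
            if (children == null) {
                arrayCount++;
                children = new CharacterNode[ALPHABET_LENGTH];
            }
            if (n == null) {
                int slot = c % ALPHABET_LENGTH;
                if (children[slot] != null) {
                    // c % 26 only separates letters of a single case; never overwrite another letter's branch
                    throw new IllegalArgumentException("Character '" + c + "' collides with '" + children[slot].c + "'");
                }
                n = new CharacterNode(this, c);
                children[slot] = n;
            }
            return n;
        }

        public CharacterNode get(char c) {
            if (children == null) {
                return null;
            }
            CharacterNode n = children[c % ALPHABET_LENGTH];
            // Check the stored character, so e.g. 'n' (110 % 26 == 6) is not mistaken for 'T' (84 % 26 == 6)
            return (n != null && n.c == c) ? n : null;
        }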

    }

    private static class CharacterNode extends Node {
        private boolean completeWord = false;
        private static StringBuilder sb = new StringBuilder();
        public CharacterNode(Node parent, char c) {
            this.parent = parent;
            this.c = c;
            instanceCount++;
        }
        public void unscramble(Map<String, Object> results, ArrayList<Character> scrambled) {
            if (completeWord) {
                results.put(outputWord(), null);
            }
            for (int i = 0; i < scrambled.size(); i++) {
                Character character = scrambled.get(i);
                CharacterNode node = get(character);
                if (node != null) {
                    int index = scrambled.indexOf(character);
                    scrambled.remove(character);
                    node.unscramble(results, scrambled);
                    scrambled.add(index, character);
                }
            }
        }

        private String outputWord() {
            sb.append(c);
            Node p = parent;
            while (p != null && p.parent != null) {
                sb.append(p.c);
                p = p.parent;
            }
            sb.reverse();
            String result = sb.toString();
            sb.delete(0, sb.length());
            return result;
        }
    }
}

ENDEDIT

I have three spell check methods implemented for comparison:

  1. Naive list iteration (at least it stops iterating once it passes the point where the misspelled word would have been in the dictionary)
  2. Binary search (quick and dirty; I only added it to show how fast the trie is)
  3. Trie search (prefix tree search). This is supposed to be the fastest.

Anyway, binary search won, by 13 ms in the run below… :(

Trie took 46 right:wrong 7263:14448
Binary took 33 right:wrong 7263:14448
List took 98224 right:wrong 7263:14448

I can get the trie to win by a margin of 2-3 millis if I use HashMaps instead of ArrayLists for the child node lists, but I think that will take too much memory on Android.


package com.clarke.agnes.verbal;

import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Parses a dictionary into a trie.
 * Created by agslinda on 7/2/14.
 */
public class TrieDictionary {

    Trie trie = new Trie();
    private ArrayList<String> words = new ArrayList<String>(280000);

    public TrieDictionary() throws IOException {
        StringBuilder sb = new StringBuilder();
        FileReader reader = new FileReader(new File("...sowpods.txt"));
        int character = reader.read();
        while (character >= 0) {
            //noinspection StatementWithEmptyBody
            if (character == '\r') {
                //do nothing
            } else if (character == '\n') {
                trie.addWord(sb.toString());
                words.add(sb.toString());
                sb.delete(0, sb.length());
            } else {
                sb.append((char)character);
            }
            character = reader.read();
        }
        reader.close();
    }

    public boolean exists(String candidateWord) {
        return trie.wordExists(candidateWord);
    }

    public boolean existsOld(String candidateWord) {
        for (int i = 0; i < words.size(); i++) {
            if (candidateWord.equals(words.get(i))) {
                return true;
            } else if (candidateWord.compareTo(words.get(i)) < 0) {
                return false;
            }
        }
        return false;
    }

    public boolean existsBinary(String candidateWord) {
        int index = words.size() / 2;
        int upper = words.size();
        int lower = 0;
        int breaker = 0;
        String wordFromDict = words.get(index);
        while (!wordFromDict.equals(candidateWord) && breaker++ < 20) {
            if (candidateWord.compareTo(wordFromDict) > 0) {
                lower = index;
                index = (index + upper) / 2;
            } else {
                upper = index;
                index = (index + lower) / 2;
            }
            wordFromDict = words.get(index);
        }

        return wordFromDict.equals(candidateWord);
    }


    public static void main(String[] args) throws IOException {
        TrieDictionary dict = new TrieDictionary();
        System.out.println("able " + dict.exists("able") + dict.existsBinary("able") + dict.existsOld("able"));
        System.out.println("zzz " + dict.exists("zzz") + dict.existsBinary("zzz") + dict.existsOld("zzz"));
        System.out.println("hjkglfy " + dict.exists("hjkglfy") + dict.existsBinary("hjkglfy") + dict.existsOld("hjkglfy"));
        System.out.println("true " + dict.exists("true") + dict.existsBinary("true") + dict.existsOld("true"));
        System.out.println("abstemious " + dict.exists("abstemious") + dict.existsBinary("abstemious") + dict.existsOld("abstemious"));

        int spell = 0, misspell = 0;
        long start = System.currentTimeMillis();
        for (int i = 0; i < dict.words.size(); i += 37) {
            String word = dict.words.get(i);
            int x = dict.exists(word) ? spell++ : misspell++;
            x = dict.exists('n' + word) ? spell++ : misspell++;
            x = dict.exists(word + 'n') ? spell++ : misspell++;
        }
        System.out.println("Trie took " + (System.currentTimeMillis() - start) + " right:wrong " + spell + ":" + misspell);

        start = System.currentTimeMillis();
        spell = 0; misspell = 0;
        for (int i = 0; i < dict.words.size(); i += 37) {
            String word = dict.words.get(i);
            int x = dict.existsBinary(word) ? spell++ : misspell++;
            x = dict.existsBinary('n' + word) ? spell++ : misspell++;
            x = dict.existsBinary(word + 'n') ? spell++ : misspell++;
        }
        System.out.println("Binary took " + (System.currentTimeMillis() - start) + " right:wrong " + spell + ":" + misspell);

        start = System.currentTimeMillis();
        spell = 0; misspell = 0;
        for (int i = 0; i < dict.words.size(); i += 37) {
            String word = dict.words.get(i);
            int x = dict.existsOld(word) ? spell++ : misspell++;
            x = dict.existsOld('n' + word) ? spell++ : misspell++;
            x = dict.existsOld(word + 'n') ? spell++ : misspell++;
        }
        System.out.println("List took " + (System.currentTimeMillis() - start) + " right:wrong " + spell + ":" + misspell);

    }

    private static class Trie {
        Node root = new Node();
        private void addWord(String word) {
            Node n = root;
            for (char c : word.toCharArray()) {
                n = n.put(c);
            }
            ((CharacterNode)n).completeWord = true;
        }

        private boolean wordExists(String candidateWord) {
            Node n = root;
            for (char c : candidateWord.toCharArray()) {
                n = n.get(c);
                if (n == null) {
                    break;
                }
            }
            return (n != null) && ((CharacterNode) n).completeWord;
        }
    }

    private static class Node {
        protected static final List<CharacterNode> empty = Collections.emptyList();
        protected List<CharacterNode> children = empty;

        public Node put(char c) {
            for (CharacterNode n : children) {
                if (n.value == c) {
                    return n;
                }
            }
            CharacterNode n = new CharacterNode(c);
            if (children == empty) { // still sharing the empty placeholder list
                children = new ArrayList<CharacterNode>();
            }
            children.add(n);
            return n;
        }

        public CharacterNode get(char c) {
            for (CharacterNode n : children) {
                if (n.value == c) {
                    return n;
                }
            }
            return null;
        }
    }

    private static class CharacterNode extends Node {
        private boolean completeWord = false;
        private final char value;

        private CharacterNode(char value) {
            this.value = value;
        }

    }

}


Here is the faster Node class. With it, the trie completes the test in 30 ms on my system instead of about 44 ms, and it consistently beats binary search!


    // Drop-in replacement for the Node class above; the file also needs
    // import java.util.HashMap; and import java.util.Map;
    private static class Node {
        protected static final Map<Character, CharacterNode> emptyMap = Collections.<Character,CharacterNode>emptyMap();
        protected Map<Character, CharacterNode> childrenMap = emptyMap;

        public Node put(char c) {
            CharacterNode n = childrenMap.get(c);
            if (childrenMap == emptyMap) { // identity check: still sharing the empty placeholder
                childrenMap = new HashMap<Character,CharacterNode>();
            }
            if (n == null) {
                n = new CharacterNode(c);
                childrenMap.put(c, n);
            }
            return n;
        }

        public CharacterNode get(char c) {
            return childrenMap.get(c);
        }
    }

You could maybe try something like this…?

A String scan is pretty fast; I imagine it compares well against autoboxing the char, computing the hash, and whatever other hoops HashMap jumps through. O(1) sounds great, but if you ever step through library implementations of simple collections, it’s amazing what complex nonsense they get up to sometimes! String operations, on the other hand, are always highly tuned.

I’d try it, but I don’t have a copy of SOWPODS lying around :D

    private static class Node {
        protected String childrenChars = "";
        // Must be initialized, otherwise the first put() throws a NullPointerException
        protected ArrayList<CharacterNode> childrenNodes = new ArrayList<CharacterNode>();

        public Node put(char c) {
            CharacterNode n = get(c);
            if (n == null) {
                childrenChars += c;
                n = new CharacterNode(c);
                childrenNodes.add(n);
            }
            return n;
        }

        public CharacterNode get(char c) {
            // The position of c in childrenChars is the index of its node in childrenNodes
            int index = childrenChars.indexOf(c);
            return (index == -1) ? null : childrenNodes.get(index);
        }
    }

Uggh. Something can only be “supposedly fastest” for a very narrowly defined set of operations/assumptions.

If you care about speed and size, here are some thoughts.

The dictionary is dynamic because obviously it changes at runtime rather than being fixed. Ya know, like: “foobaz” suddenly becomes a legal word. The real cost of dynamic data structures is random memory access. You mention Android, so it would be sane to reduce memory size. Speed? How many queries do you expect to perform per game “frame”? Do you need an AI to figure out the legal words it can create? If SOWPODS is the real target, the dataset has a fair set of limitations, e.g. words are only 2-15 characters long (assuming Wikipedia is correct). If you really wanted to check lots of potential words for legality per frame, a Bloom filter could be a reasonable idea.
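A Bloom filter trades exactness for space: it never rejects a word that was added, but it can occasionally accept one that wasn't. Here is a minimal sketch, assuming a fixed word list. `WordBloomFilter`, the bit-array size and the hash count are illustrative choices, and deriving the k bit positions from `String.hashCode()` plus a hand-written FNV-1a hash (double hashing) is one common technique, not something taken from this thread.

```java
import java.util.BitSet;

// Sketch of a Bloom filter for dictionary membership.
// Answers "definitely not added" or "probably added"; false positives are possible, false negatives are not.
// Sizes and hash functions are illustrative assumptions, not tuned values.
final class WordBloomFilter {
    private final BitSet bits;
    private final int numBits;
    private final int numHashes;

    WordBloomFilter(int numBits, int numHashes) {
        this.bits = new BitSet(numBits);
        this.numBits = numBits;
        this.numHashes = numHashes;
    }

    // Second hash: 32-bit FNV-1a over the string's UTF-16 chars.
    private static int fnv1a(String s) {
        int h = 0x811C9DC5;
        for (int i = 0; i < s.length(); i++) {
            h ^= s.charAt(i);
            h *= 0x01000193;
        }
        return h;
    }

    // Double hashing: position_i = h1 + i * h2 (mod numBits); floorMod keeps it non-negative.
    private int position(int h1, int h2, int i) {
        return Math.floorMod(h1 + i * h2, numBits);
    }

    void add(String word) {
        int h1 = word.hashCode();
        int h2 = fnv1a(word);
        for (int i = 0; i < numHashes; i++) {
            bits.set(position(h1, h2, i));
        }
    }

    boolean mightContain(String word) {
        int h1 = word.hashCode();
        int h2 = fnv1a(word);
        for (int i = 0; i < numHashes; i++) {
            if (!bits.get(position(h1, h2, i))) {
                return false; // an unset bit: this word was definitely never added
            }
        }
        return true; // all bits set: probably added (could be a false positive)
    }

    public static void main(String[] args) {
        // ~10 bits per word and 7 hashes for an assumed 270,000 words: roughly a 1% false-positive rate.
        WordBloomFilter filter = new WordBloomFilter(2_700_000, 7);
        filter.add("ABLE");
        System.out.println(filter.mightContain("ABLE"));    // true (no false negatives)
        System.out.println(filter.mightContain("HJKGLFY")); // almost certainly false
    }
}
```

Note that this only answers yes/no membership; it cannot enumerate the words that can be built from a set of letters.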

It probably would be if you implemented it properly. Linear search through a list at each level is not the classic trie implementation. It should be a single array lookup. The problem then is memory usage; I’ve found that if you have an expensive offline building process which generates the trie and then combines isomorphic nodes, it takes about as much space as the original word list but performs like a trie.
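The offline “combine isomorphic nodes” step can be sketched as a bottom-up pass: two nodes are interchangeable if they have the same end-of-word flag and the same labelled children, so a map from that signature to a canonical node merges them, and shared suffixes end up stored once. This is a minimal sketch under assumptions: it uses `TreeMap` children for clarity rather than arrays, `SuffixSharingTrie` and its methods are hypothetical names, and after `minimize()` the structure must be treated as read-only, because a shared node belongs to several words.

```java
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.TreeMap;

// Sketch: build a trie, then merge isomorphic subtrees. The result is a DAG that answers
// lookups exactly like the original trie but with fewer nodes. Do not call add() after minimize().
final class SuffixSharingTrie {
    static final class Node {
        boolean end;
        final TreeMap<Character, Node> children = new TreeMap<>();
    }

    final Node root = new Node();

    void add(String word) {
        Node n = root;
        for (char c : word.toCharArray()) {
            n = n.children.computeIfAbsent(c, k -> new Node());
        }
        n.end = true;
    }

    boolean contains(String word) {
        Node n = root;
        for (int i = 0; i < word.length() && n != null; i++) {
            n = n.children.get(word.charAt(i));
        }
        return n != null && n.end;
    }

    // Offline step: replace every subtree by a canonical representative.
    void minimize() {
        canonicalize(root, new HashMap<>(), new IdentityHashMap<>());
    }

    private Node canonicalize(Node n, Map<String, Node> registry, Map<Node, Integer> ids) {
        // Children first, so their canonical ids are known when this node's signature is built.
        StringBuilder sig = new StringBuilder(n.end ? "1" : "0");
        for (Map.Entry<Character, Node> e : n.children.entrySet()) {
            Node canon = canonicalize(e.getValue(), registry, ids);
            e.setValue(canon);
            sig.append(e.getKey()).append(ids.get(canon)).append(',');
        }
        Node existing = registry.putIfAbsent(sig.toString(), n);
        if (existing != null) {
            return existing; // an identical subtree was already registered: reuse it
        }
        ids.put(n, ids.size());
        return n;
    }

    // Counts distinct nodes, visiting shared nodes only once.
    int countNodes() {
        Map<Node, Boolean> seen = new IdentityHashMap<>();
        count(root, seen);
        return seen.size();
    }

    private static void count(Node n, Map<Node, Boolean> seen) {
        if (seen.put(n, Boolean.TRUE) != null) {
            return;
        }
        for (Node child : n.children.values()) {
            count(child, seen);
        }
    }

    public static void main(String[] args) {
        SuffixSharingTrie t = new SuffixSharingTrie();
        for (String w : new String[] {"TAP", "TAPS", "TOP", "TOPS"}) {
            t.add(w);
        }
        System.out.println(t.countNodes()); // 8
        t.minimize();
        System.out.println(t.countNodes()); // 5: the "A" and "O" branches now share "P", "PS"
        System.out.println(t.contains("TOPS") + " " + t.contains("TOS")); // true false
    }
}
```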

OK, here is the (I think) correct implementation: I look up the child nodes in an array, with the index calculated as char % ALPHABET_LENGTH.


    private static class ArrayNode2 {
        protected static int instanceCount = 0;
        private static final int ALPHABET_LENGTH = 52;
        protected ArrayCharacterNode2[] children = new ArrayCharacterNode2[ALPHABET_LENGTH];

        public ArrayNode2 put(char c) {
            ArrayCharacterNode2 n = get(c);
            if (n == null) {
                n = new ArrayCharacterNode2(c);
                children[c % ALPHABET_LENGTH] = n;
            }
            return n;
        }

        public ArrayCharacterNode2 get(char c) {
            return children[c % ALPHABET_LENGTH];
        }
    }

This gives these results (I increased the size of the test, btw):

EDIT: I updated the results after increasing the memory for the process.

ListTrie took 428 right:wrong 172880:345760
MapTrie took 74 right:wrong 172880:345760
ArrayTrie took 70 right:wrong 172880:345760 //richierich method
ArrayTrie2 took 61 right:wrong 172880:345760 //modulus method
Binary took 134 right:wrong 172880:345760

Twice as fast is good enough for me :)

Very clever - nice to see incremental improvements occurring!

But why is the alphabet size 52 there and not 26? For mixed-case letters? In that case, isn't it assuming that lowercase and uppercase are contiguous Unicode ranges?

Very interesting thread; I’d never looked at tries before, although I’d heard of them! With my sensible hat on, I agree with Roquen that going with the low-tech, low-memory solution (i.e. binary-searching the basic word list) might be best, given how similar the speeds are for the expected data volume…
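Binary-searching the basic list doesn't need a hand-rolled loop: the JDK's `Collections.binarySearch` does the same job on a sorted `List`. A minimal sketch, assuming the words are in `String.compareTo` order; `SortedWordList` is a made-up name for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Sketch: membership test on a sorted word list using the JDK's binary search.
final class SortedWordList {
    private final List<String> words;

    SortedWordList(List<String> words) {
        this.words = new ArrayList<>(words);
        Collections.sort(this.words); // binarySearch requires ascending natural order
    }

    boolean exists(String candidate) {
        // Returns the index if found, otherwise (-(insertion point) - 1), which is always negative.
        return Collections.binarySearch(words, candidate) >= 0;
    }

    public static void main(String[] args) {
        SortedWordList dict = new SortedWordList(Arrays.asList("able", "true", "abstemious"));
        System.out.println(dict.exists("able"));    // true
        System.out.println(dict.exists("hjkglfy")); // false
    }
}
```

The only memory beyond the word list itself is the list's backing array, which is the appeal on Android.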

52 because I am one dumb-assed programmer. I was working from a different dictionary this morning where all the words are uppercase, but my test misspellings are word + ‘n’ (lowercase). That works fine for the other methods, which test the actual character at each node, but with the mod method an invalid character sends you down the wrong branch of the tree. I thought it was an error in the dictionary and only later realized it was my test code. I should check for this and throw an InvalidCharacterException in such cases :)
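That check can be done by computing the array slot explicitly and rejecting anything that isn't a letter, rather than relying on `%`. A minimal sketch: folding case is an assumption (it suits Scrabble words), `LetterIndex` is a hypothetical helper, and it throws the standard `IllegalArgumentException` instead of a custom exception.

```java
// Sketch: map a letter to a child-array slot explicitly instead of c % ALPHABET_LENGTH.
// Assumption: words contain only English letters; case is folded so 'n' and 'N' share a slot.
// Anything else is rejected rather than silently aliased onto another letter's branch.
final class LetterIndex {
    static final int ALPHABET_LENGTH = 26;

    static int indexOf(char c) {
        if (c >= 'A' && c <= 'Z') {
            return c - 'A';
        }
        if (c >= 'a' && c <= 'z') {
            return c - 'a';
        }
        throw new IllegalArgumentException("Not a letter A-Z: '" + c + "'");
    }

    public static void main(String[] args) {
        System.out.println(indexOf('A') + " " + indexOf('n') + " " + indexOf('N')); // 0 13 13
        // With the modulus, 'n' % 26 == 6 is the same slot as 'T' % 26, hence the wrong branch.
        System.out.println(('n' % 26) + " " + ('T' % 26)); // 6 6
    }
}
```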

Anyways, now that the arrays are 26 long, the trie goes even faster (not sure why, though); the new method is about 40% faster than it was.

If it were just spell checking, binary search would be enough, but I need to work out all possible words for random sets of letters. A trie is much better suited to that: it's faster, and it lets you prune many candidate words without having to test them.
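The pruning can be sketched as a depth-first walk that only follows a child while the rack still holds that letter, so whole families of non-words are never generated. A minimal sketch, assuming uppercase A-Z input; `RackSolver` and its methods are hypothetical names. Tracking letter counts instead of a list of remaining characters (as `CharacterNode.unscramble` does) is a design choice of this sketch: it avoids exploring the same branch twice when the rack has repeated letters.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;

// Sketch: find every dictionary word that can be spelled from a rack of letters.
// Assumption: words and racks are uppercase A-Z only.
final class RackSolver {
    private static final class Node {
        final Node[] children = new Node[26];
        boolean end;
    }

    private final Node root = new Node();

    void add(String word) {
        Node n = root;
        for (char c : word.toCharArray()) {
            int i = c - 'A';
            if (n.children[i] == null) {
                n.children[i] = new Node();
            }
            n = n.children[i];
        }
        n.end = true;
    }

    List<String> wordsFrom(String rack) {
        int[] counts = new int[26];
        for (char c : rack.toCharArray()) {
            counts[c - 'A']++;
        }
        TreeSet<String> found = new TreeSet<>(); // sorted output
        search(root, counts, new StringBuilder(), found);
        return new ArrayList<>(found);
    }

    private void search(Node node, int[] counts, StringBuilder prefix, TreeSet<String> found) {
        if (node.end) {
            found.add(prefix.toString());
        }
        for (int i = 0; i < 26; i++) {
            Node child = node.children[i];
            // Prune: no word continues this way, or the rack has no such letter left.
            if (child == null || counts[i] == 0) {
                continue;
            }
            counts[i]--;
            prefix.append((char) ('A' + i));
            search(child, counts, prefix, found);
            prefix.setLength(prefix.length() - 1); // backtrack
            counts[i]++;
        }
    }

    public static void main(String[] args) {
        RackSolver solver = new RackSolver();
        for (String w : new String[] {"AT", "TA", "TEA", "EAT", "ATE", "TEAT", "TEE"}) {
            solver.add(w);
        }
        System.out.println(solver.wordsFrom("TEA")); // [AT, ATE, EAT, TA, TEA]
    }
}
```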

You can fit more Nodes in your CPU cache, leading to fewer cache misses. Memory latency is your bottleneck, so reducing your memory footprint boosts performance.

My next improvement is to stop allocating useless child arrays on leaf nodes, which should cut down the cache misses a bit more.

Yes, that is what I meant in my previous post: I will only create the array when I have to write into it.

At the cost of complicating your code, it’s probably way faster if you traverse the tree 2 chars at a time. Your lookup tables would be 26*26=676 bytes - they easily fit in (4K) pages.

Why would the lookup table entries be one byte in length?

Excuse me, multiply that by 4, as it’s a table of references.

(To the best of my knowledge, even in 64-bit JVMs references are 4-byte pointers, multiplied by 8 to calculate the memory address.)

Thought so… The lookup table would be 26 x 27 x 4 = 2808 bytes. It needs an extra column for odd-length words. I might give it a try this weekend.
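A two-letters-per-step trie with that 26 x 27 layout could look roughly like this: the first letter picks a row, the second picks a column, and column 26 stands for “the word ends after the first letter”. This is a minimal sketch under assumptions: uppercase A-Z only, tables allocated lazily because 702 references per node is expensive, and `PairTrie` is a hypothetical name.

```java
// Sketch: a trie that consumes two letters per step.
// Each node's table has 26 rows x 27 columns; column 26 marks an odd-length ending.
// Assumption: uppercase A-Z only (other characters would index outside the table).
final class PairTrie {
    private static final int ROW = 27; // 26 letters + end-of-word column
    private static final int END = 26;

    private static final class Node {
        Node[] table; // null until the first child is added
        boolean end;
    }

    private final Node root = new Node();

    // Slot for the pair starting at position i; a lone final letter pairs with END.
    private static int slot(String w, int i) {
        int first = w.charAt(i) - 'A';
        int second = (i + 1 < w.length()) ? w.charAt(i + 1) - 'A' : END;
        return first * ROW + second;
    }

    void add(String word) {
        Node n = root;
        for (int i = 0; i < word.length(); i += 2) {
            if (n.table == null) {
                n.table = new Node[26 * ROW];
            }
            int s = slot(word, i);
            if (n.table[s] == null) {
                n.table[s] = new Node();
            }
            n = n.table[s];
        }
        n.end = true;
    }

    boolean contains(String word) {
        Node n = root;
        for (int i = 0; i < word.length(); i += 2) {
            if (n.table == null) {
                return false;
            }
            n = n.table[slot(word, i)];
            if (n == null) {
                return false;
            }
        }
        return n.end;
    }

    public static void main(String[] args) {
        PairTrie t = new PairTrie();
        t.add("CAT");
        t.add("CATS");
        System.out.println(t.contains("CAT") + " " + t.contains("CATS") + " " + t.contains("CA")); // true true false
    }
}
```

With 4-byte compressed references, each allocated table is the 2808 bytes computed above (plus the array header), so lazy allocation matters even more here than in the one-letter version.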

64-bit HotSpot defaults to compressed oops (for heaps under about 32 GB), so yeah, 4 bytes per reference (the per-object headers are bigger, though).

Well, I’ve built it up into my own little scrabble cheat program… If you write the cheat program yourself, is it still cheating?

Compared to spell checking, finding words from random letters is really slow. I can spell check the whole dictionary plus twice that many misspellings in 80 milliseconds, but finding the words for one set of nine letters takes 10 to 15 milliseconds.

Sigh. That’s why I asked about AI. True or false for a word is a different problem than a specialized regexp search.

In my original post I did say “I need fast dictionary lookup and also to find words from random sets [of letters]”.

Indeed you did, but it was lost in the noise. In this case it doesn’t matter, but if you were writing a game that required this functionality, then matching is the hard problem and the one to address, not the true/false question.