Test results (that nobody is waiting for ;))
normal grid: 58906ms locality grid: 26688ms normal grid: 59078ms locality grid: 26906ms
The dataset is 1024x256x1024 cells (256MB), so the cache is swamped anyway. You can clearly see that this locality eliminates most of the cache misses, yet those that remain have a disproportionate impact, thus giving ‘only’ a factor ~2 performance boost.
/**
 * A 1024 x 256 x 1024 byte grid stored as 8x8x8 blocks, so that cells that are
 * close together in space are also close together in the backing array.
 *
 * <p>An index is built from the block number (upper bits) and the offset of
 * the cell inside its block (lower 9 bits).
 */
public class LocalityGrid
{
    /** Number of cells in one 8x8x8 block. */
    private static final int localStride = 8 * 8 * 8;
    /** log2 of {@link #localStride}: shifting by this multiplies by the block size. */
    private static final int localStrideMulAsShift = 3 + 3 + 3;
    /** Bits needed for the block coordinate along x (1024 / 8 blocks). */
    private static final int xShift = 10 - 3;
    /** Bits needed for the block coordinate along y (256 / 8 blocks). */
    private static final int yShift = 8 - 3;
    /** Bits needed for the block coordinate along z (1024 / 8 blocks). */
    private static final int zShift = 10 - 3;
    /** Position of the y block coordinate inside the block number. */
    private static final int xzShift = xShift + zShift;

    /**
     * Allocates a zero-filled backing array large enough for the whole grid.
     *
     * @return a new array of 1024 * 256 * 1024 bytes
     */
    public static byte[] create()
    {
        return new byte[1 << (xShift + yShift + zShift + localStrideMulAsShift)];
    }

    /**
     * Sums the values of the 26 cells surrounding {@code (x, y, z)}; the cell
     * itself is not included.
     *
     * <p>The calls are deliberately written out one by one rather than looped.
     *
     * @param grid backing array as returned by {@link #create()}
     * @param x    x coordinate of the centre cell
     * @param y    y coordinate of the centre cell
     * @param z    z coordinate of the centre cell
     * @return the sum of the 26 neighbouring cell values
     */
    public static int countNeighbours(byte[] grid, int x, int y, int z)
    {
        int count = 0;

        // layer z - 1
        count += get(grid, x - 1, y - 1, z - 1);
        count += get(grid, x, y - 1, z - 1);
        count += get(grid, x + 1, y - 1, z - 1);
        count += get(grid, x - 1, y, z - 1);
        count += get(grid, x, y, z - 1);
        count += get(grid, x + 1, y, z - 1);
        count += get(grid, x - 1, y + 1, z - 1);
        count += get(grid, x, y + 1, z - 1);
        count += get(grid, x + 1, y + 1, z - 1);

        // layer z (centre cell skipped)
        count += get(grid, x - 1, y - 1, z);
        count += get(grid, x, y - 1, z);
        count += get(grid, x + 1, y - 1, z);
        count += get(grid, x - 1, y, z);
        count += get(grid, x + 1, y, z);
        count += get(grid, x - 1, y + 1, z);
        count += get(grid, x, y + 1, z);
        count += get(grid, x + 1, y + 1, z);

        // layer z + 1
        count += get(grid, x - 1, y - 1, z + 1);
        // Fixed: this line used to read z - 1, counting (x, y - 1, z - 1) twice
        // and never looking at (x, y - 1, z + 1).
        count += get(grid, x, y - 1, z + 1);
        count += get(grid, x + 1, y - 1, z + 1);
        count += get(grid, x - 1, y, z + 1);
        count += get(grid, x, y, z + 1);
        count += get(grid, x + 1, y, z + 1);
        count += get(grid, x - 1, y + 1, z + 1);
        count += get(grid, x, y + 1, z + 1);
        count += get(grid, x + 1, y + 1, z + 1);

        return count;
    }

    /**
     * Reads the cell at {@code (x, y, z)}.
     *
     * @return the stored cell value
     */
    public static byte get(byte[] grid, int x, int y, int z)
    {
        return grid[toIndex(x, y, z)];
    }

    /**
     * Writes {@code cell} to the cell at {@code (x, y, z)}.
     */
    public static void set(byte[] grid, int x, int y, int z, byte cell)
    {
        grid[toIndex(x, y, z)] = cell;
    }

    /**
     * Converts cell coordinates into an offset into the backing array.
     *
     * <p>The block number is {@code x/8 | (z/8) << 7 | (y/8) << 14}; the
     * offset inside the block is {@code x%8 | (z%8) << 3 | (y%8) << 6}.
     */
    private static int toIndex(int x, int y, int z)
    {
        int iChunk = (x >> 3);
        iChunk |= (y >> 3) << xzShift;
        iChunk |= (z >> 3) << xShift;
        int index = iChunk << localStrideMulAsShift;
        index |= ((z & 7) << 3) | ((y & 7) << 6) | (x & 7);
        return index;
    }
}
/**
 * A 1024 x 256 x 1024 byte grid stored in plain linear order: x varies
 * fastest, then z, then y.
 */
public class NormalGrid
{
    /** Bits used by the x coordinate (1024 cells). */
    private static final int xShift = 10;
    /** Bits used by the y coordinate (256 cells). */
    private static final int yShift = 8;
    /** Bits used by the z coordinate (1024 cells). */
    private static final int zShift = 10;
    /** Position of the y coordinate inside an index. */
    private static final int xzShift = xShift + zShift;

    /**
     * Allocates a zero-filled backing array large enough for the whole grid.
     *
     * @return a new array of 1024 * 256 * 1024 bytes
     */
    public static byte[] create()
    {
        return new byte[1 << (xShift + yShift + zShift)];
    }

    /**
     * Sums the values of the 26 cells surrounding {@code (x, y, z)}; the cell
     * itself is not included.
     *
     * <p>The calls are deliberately written out one by one rather than looped.
     *
     * @param grid backing array as returned by {@link #create()}
     * @param x    x coordinate of the centre cell
     * @param y    y coordinate of the centre cell
     * @param z    z coordinate of the centre cell
     * @return the sum of the 26 neighbouring cell values
     */
    public static int countNeighbours(byte[] grid, int x, int y, int z)
    {
        int count = 0;

        // layer z - 1
        count += get(grid, x - 1, y - 1, z - 1);
        count += get(grid, x, y - 1, z - 1);
        count += get(grid, x + 1, y - 1, z - 1);
        count += get(grid, x - 1, y, z - 1);
        count += get(grid, x, y, z - 1);
        count += get(grid, x + 1, y, z - 1);
        count += get(grid, x - 1, y + 1, z - 1);
        count += get(grid, x, y + 1, z - 1);
        count += get(grid, x + 1, y + 1, z - 1);

        // layer z (centre cell skipped)
        count += get(grid, x - 1, y - 1, z);
        count += get(grid, x, y - 1, z);
        count += get(grid, x + 1, y - 1, z);
        count += get(grid, x - 1, y, z);
        count += get(grid, x + 1, y, z);
        count += get(grid, x - 1, y + 1, z);
        count += get(grid, x, y + 1, z);
        count += get(grid, x + 1, y + 1, z);

        // layer z + 1
        count += get(grid, x - 1, y - 1, z + 1);
        // Fixed: this line used to read z - 1, counting (x, y - 1, z - 1) twice
        // and never looking at (x, y - 1, z + 1).
        count += get(grid, x, y - 1, z + 1);
        count += get(grid, x + 1, y - 1, z + 1);
        count += get(grid, x - 1, y, z + 1);
        count += get(grid, x, y, z + 1);
        count += get(grid, x + 1, y, z + 1);
        count += get(grid, x - 1, y + 1, z + 1);
        count += get(grid, x, y + 1, z + 1);
        count += get(grid, x + 1, y + 1, z + 1);

        return count;
    }

    /**
     * Reads the cell at {@code (x, y, z)}.
     *
     * @return the stored cell value
     */
    public static byte get(byte[] grid, int x, int y, int z)
    {
        return grid[toIndex(x, y, z)];
    }

    /**
     * Writes {@code cell} to the cell at {@code (x, y, z)}.
     */
    public static void set(byte[] grid, int x, int y, int z, byte cell)
    {
        grid[toIndex(x, y, z)] = cell;
    }

    /**
     * Converts cell coordinates into an offset into the backing array:
     * {@code x | z << 10 | y << 20}.
     */
    private static int toIndex(int x, int y, int z)
    {
        int index = x;
        index |= y << xzShift;
        index |= z << xShift;
        return index;
    }
}
/**
 * Times a full neighbour-count sweep over {@code NormalGrid} and
 * {@code LocalityGrid} and prints the elapsed milliseconds of each, four
 * rounds in a row.
 */
public class GridBenchmark
{
    /** Nanoseconds per millisecond, for converting {@link System#nanoTime()} deltas. */
    private static final long NANOS_PER_MILLI = 1000000L;

    public static void main(String[] args)
    {
        for (int i = 0; i < 4; i++)
        {
            System.out.println("normal grid: " + benchNormalGrid() + "ms");
            System.out.println("locality grid: " + benchLocalityGrid() + "ms");
        }
    }

    /**
     * Runs one sweep over a freshly created {@code NormalGrid}.
     *
     * @return elapsed time of the sweep in milliseconds (allocation excluded)
     */
    private static long benchNormalGrid()
    {
        byte[] grid = NormalGrid.create();
        // nanoTime is monotonic; currentTimeMillis follows the wall clock and
        // can jump if the system time is adjusted during the run.
        long t0 = System.nanoTime();
        for (int x = 1; x < 1023; x++)
        {
            for (int y = 1; y < 254; y++)
            {
                for (int z = 1; z < 1023; z++)
                {
                    NormalGrid.countNeighbours(grid, x, y, z);
                }
            }
        }
        long t1 = System.nanoTime();
        return (t1 - t0) / NANOS_PER_MILLI;
    }

    /**
     * Runs one sweep over a freshly created {@code LocalityGrid}, visiting
     * the cells in the same order as {@link #benchNormalGrid()}.
     *
     * @return elapsed time of the sweep in milliseconds (allocation excluded)
     */
    private static long benchLocalityGrid()
    {
        byte[] grid = LocalityGrid.create();
        long t0 = System.nanoTime();
        for (int x = 1; x < 1023; x++)
        {
            for (int y = 1; y < 254; y++)
            {
                for (int z = 1; z < 1023; z++)
                {
                    LocalityGrid.countNeighbours(grid, x, y, z);
                }
            }
        }
        long t1 = System.nanoTime();
        return (t1 - t0) / NANOS_PER_MILLI;
    }
}
Update:
normal ==> 59s 2x 2x 2 ==> 42s 4x 4x 4 ==> 30s 8x 8x 8 ==> 26s 16x16x16 ==> 23s 32x16x32 ==> 23s 32x32x32 ==> 23s
Update:
normal ==> 59s 2x 2x 2 ==> 39s 4x 4x 4 ==> 26s 8x 8x 8 ==> 24s 16x16x16 ==> 20s 32x32x32 ==> 19s
Optimizing instructions for the JIT… (making any change, such as removing local variables, adds 3s!!)
//this:
int a = (...) | (...) | (...);
// is SLOWER than
int a = 0;
a |= (...);
a |= (...);
a |= (...);
// Layout constants for a grid split into cubic blocks of (1 << bits) cells per edge.
// log2 of the block edge length.
private static final int bits = 5; // 32x32x32
// NOTE(review): evaluates to 6, which is not 2 * bits (10). Used as the shift of a
// third bits-wide field it would overlap a field shifted by bits; confirm the intended value.
private static final int bits_plus_1 = bits + 1;
// Low `bits` bits set: selects a coordinate's position inside its block.
private static final int mask = ~(-1 << bits); // 31
// Number of index bits taken by the in-block offset (three fields of `bits` bits each).
private static final int strideShift = bits + bits + bits;
// Bits left for the block coordinate on each axis (10-bit x and z, 8-bit y).
private static final int xBits = (10 - bits);
private static final int zBits = (10 - bits);
private static final int yBits = (8 - bits);
// Bit positions of the block coordinates inside an index: x lowest, then z, then y.
private static final int xShift = strideShift;
private static final int zShift = xBits + strideShift;
private static final int yShift = (xBits + zBits) + strideShift;
/**
 * Allocates a zero-filled backing array with one byte per index that the
 * block and in-block bit fields can address.
 *
 * @return a new array of {@code 1 << (xBits + yBits + zBits + strideShift)} bytes
 */
public static byte[] create()
{
    final int indexBits = xBits + yBits + zBits + strideShift;
    final int cellCount = 1 << indexBits;
    return new byte[cellCount];
}
// set & get
/**
 * Converts cell coordinates into an offset into the backing array.
 *
 * <p>The upper bits select the block (x, then z, then y block coordinate);
 * the lower {@code 3 * bits} bits select the cell inside the block, laid out
 * as x in the lowest {@code bits} bits, then z, then y.
 *
 * @param x x coordinate of the cell
 * @param y y coordinate of the cell
 * @param z z coordinate of the cell
 * @return the array offset of the cell
 */
private static int toIndex(int x, int y, int z)
{
    // Block part. Kept as separate |= statements on purpose: the author
    // measured this form as faster than a single combined expression.
    int global = 0;
    global |= (x >> bits) << xShift;
    global |= (y >> bits) << yShift;
    global |= (z >> bits) << zShift;

    // In-block part. Fixed: x was never used (z appeared twice), and y was
    // shifted by bits + 1 instead of 2 * bits, so it overlapped the z field.
    int local = 0;
    local |= (x & mask);
    local |= (z & mask) << bits;
    local |= (y & mask) << (bits + bits);
    return global | local;
}
59s / 19s = factor 3.1 8)