Here we go:
I have three classes: color, palette, and image.
The color:
package ici.data;
/**
 * A color made of three integer channels (red, green, blue), meant to be
 * stored as one entry of the color palette.
 * @author Rafael Carvallo
 */
public class Color3i{
  /** Red channel value (presumably 0-255; not validated here). */
  public int r;
  /** Green channel value (presumably 0-255; not validated here). */
  public int g;
  /** Blue channel value (presumably 0-255; not validated here). */
  public int b;

  /** Creates black: every channel is 0. */
  public Color3i(){
    this(0, 0, 0);
  }

  /**
   * Creates a color from explicit channel values.
   * @param r red channel
   * @param g green channel
   * @param b blue channel
   */
  public Color3i(int r, int g, int b){
    this.r = r;
    this.g = g;
    this.b = b;
  }

  /** @return the channels as {@code "r g b"}, separated by single spaces. */
  @Override
  public String toString(){
    return new StringBuilder()
        .append(r).append(' ')
        .append(g).append(' ')
        .append(b)
        .toString();
  }
}
The palette:
package ici.data;
import java.awt.image.*;
import java.util.*;
/**
 * Class that represents the palette for the image.
 * The first color (index 0) is always transparent in the drawing.
 * @author Rafael Carvallo
 */
public class ICPalette {
  /** Palette entries; an entry's position in this vector is its pixel index. */
  Vector<Color3i> colores;

  /** Creates an empty palette. */
  public ICPalette(){
    colores = new Vector<Color3i>();
  }

  /**
   * Builds an 8-bit indexed color model from the current entries, with
   * entry 0 marked as the transparent color.
   * @return the color model, or {@code null} if the palette has no colors
   */
  public IndexColorModel getICM(){
    int cc = colores.size();
    if (cc == 0) {
      return null;
    }
    // Packed r,g,b triples, one per palette entry, in index order.
    byte[] data = new byte[cc * 3];
    int k = 0;
    for (Color3i c3i : colores) {
      // The (byte) cast keeps only the low 8 bits, so channels are expected in 0-255.
      data[k++] = (byte) c3i.r;
      data[k++] = (byte) c3i.g;
      data[k++] = (byte) c3i.b;
    }
    // 8 bits per pixel, RGB without alpha starting at offset 0, index 0 transparent.
    return new IndexColorModel(8, cc, data, 0, false, 0);
  }

  /**
   * Appends a color at the end of the palette; its index is the previous color count.
   * @param r red channel
   * @param g green channel
   * @param b blue channel
   */
  public void insertColor(int r, int g, int b){
    colores.addElement(new Color3i(r, g, b));
  }

  /**
   * Removes the color at {@code index}; later colors shift down by one index.
   * @param index palette index to remove
   * @throws ArrayIndexOutOfBoundsException if the index is out of range
   */
  public void deleteColor(int index){
    colores.removeElementAt(index);
  }

  /**
   * Overwrites the channels of an existing palette entry in place.
   * @param index palette index to change
   * @param r new red channel
   * @param g new green channel
   * @param b new blue channel
   * @throws ArrayIndexOutOfBoundsException if the index is out of range
   */
  public void updateColor(int index, int r, int g, int b){
    Color3i c3i = colores.elementAt(index);
    c3i.r = r;
    c3i.g = g;
    c3i.b = b;
  }

  /** @return the number of colors currently in the palette */
  public int getColorCount(){
    return colores.size();
  }

  /**
   * Returns the live backing vector (not a copy); changes to it affect this palette.
   * @return the palette entries in index order
   */
  public Vector<Color3i> getColorVector(){
    return colores;
  }

  /**
   * @param index palette index
   * @return the (mutable) color stored at {@code index}
   * @throws ArrayIndexOutOfBoundsException if the index is out of range
   */
  public Color3i getColor(int index){
    return colores.elementAt(index);
  }
}
The image:
package ici.data;
import java.awt.*;
import java.awt.image.*;
import java.util.*;
/**
 * This class represents an image independent from the palette: it stores
 * one palette index per pixel and is rendered with whichever palette is
 * passed to {@link #getImage(ICPalette)}.
 * @author Rafael Carvallo
 *
 */
public class ICImage {
  /** Palette indices in row-major order: pixel (x, y) is stored at data[y * w + x]. */
  int[] data;
  /** Image width and height in pixels. */
  int w,h;

  /** Creates an image with no pixel data; call setPixelArray before rendering it. */
  public ICImage(){
  }

  /**
   * Creates a W x H image with every pixel set to palette index 0.
   * @param W width in pixels
   * @param H height in pixels
   */
  public ICImage(int W, int H) {
    w = W;
    h = H;
    data = new int[w*h];
  }

  /** Generates the image with the data defined and the palette passed.
   * @param pal An ICPalette to be used in the creation of the image.
   * @return a new indexed BufferedImage of size w x h
   * @throws IllegalArgumentException if the palette has no colors
   * @throws IllegalStateException if this image has no pixel data
   */
  public BufferedImage getImage(ICPalette pal){
    IndexColorModel cm1 = pal.getICM();
    if (cm1 == null) {
      // getICM() returns null for an empty palette; BufferedImage would otherwise NPE.
      throw new IllegalArgumentException("palette has no colors");
    }
    if (data == null) {
      throw new IllegalStateException("image has no pixel data");
    }
    BufferedImage im = new BufferedImage(w,h,BufferedImage.TYPE_BYTE_INDEXED,cm1);
    WritableRaster rast = im.getRaster();
    // Single-band raster: data is consumed in row-major order, matching our layout.
    rast.setPixels(0,0,w,h,data);
    return im;
  }

  /**
   * Copies the pixels into a new 2-D array indexed as {@code arr[x][y]}.
   * @return a w x h array of palette indices
   */
  public int[][] getPixelArray(){
    int[][] arr = new int[w][h];
    for(int x = 0; x < w; x++){
      for(int y = 0; y < h; y++){
        arr[x][y] = data[y*w + x];
      }
    }
    return arr;
  }

  /**
   * Replaces all pixel data (and the image size) from an array indexed as
   * {@code arr[x][y]}. The width is {@code arr.length} and the height is the
   * length of the columns. An empty array yields a 0 x 0 image.
   * @param arr rectangular array of palette indices
   * @throws IllegalArgumentException if the columns differ in length
   */
  public void setPixelArray(int[][] arr){
    int newW = arr.length;
    int newH = (newW == 0) ? 0 : arr[0].length;
    int[] newData = new int[newW*newH];
    for(int x = 0; x < newW; x++){
      if(arr[x].length != newH){
        throw new IllegalArgumentException("column " + x + " has length "
            + arr[x].length + ", expected " + newH);
      }
      for(int y = 0; y < newH; y++){
        newData[y*newW + x] = arr[x][y];
      }
    }
    // Commit only after the whole array was validated and copied.
    w = newW;
    h = newH;
    data = newData;
  }

  /**
   * Sets one pixel to a palette index.
   * @param x column, 0 <= x < width
   * @param y row, 0 <= y < height
   * @param index palette index to store
   * @throws IndexOutOfBoundsException if (x, y) lies outside the image
   */
  public void setPixel(int x, int y, int index){
    // Without this check an out-of-range x would silently wrap into the next row.
    if(x < 0 || x >= w || y < 0 || y >= h){
      throw new IndexOutOfBoundsException("pixel (" + x + ", " + y
          + ") outside " + w + "x" + h + " image");
    }
    data[y*w + x] = index;
  }

  /** @return the image size as a Dimension (width, height) */
  public Dimension getSize(){
    return new Dimension(w,h);
  }
}
A test:
import ici.data.ICImage;
import ici.data.ICPalette;
import java.awt.*;
import javax.swing.JFrame;
/**
 * Demo: renders the same index image with a red palette and a blue palette
 * and draws both, overlapping, so palette index 0 shows as transparent.
 */
class Test extends Canvas 
{
  // NOTE: these hide ImageObserver.WIDTH/HEIGHT inherited through Canvas;
  // here they are just the frame size in pixels.
  public static final int WIDTH=250;
  public static final int HEIGHT=250;
  /** The shared index image rendered with palette1 (i1) and palette2 (i2). */
  Image i1, i2;

  /** Builds both palettes and a 100x100 index image, then renders it twice. */
  public Test(){
    ICPalette palette1 = new ICPalette();
    palette1.insertColor(255,255,255); //white, index 0 = transparent
    palette1.insertColor(255,0,0); //red
    palette1.insertColor(200,0,0); //less red
    palette1.insertColor(150,0,0); //less red
    palette1.insertColor(100,0,0); //less red
    ICPalette palette2 = new ICPalette();
    palette2.insertColor(255,255,255); //white, index 0 = transparent
    palette2.insertColor(0,0,255); //blue
    palette2.insertColor(0,0,200); //less blue
    palette2.insertColor(0,0,150); //less blue
    palette2.insertColor(0,0,100); //less blue
    ICImage image = new ICImage(100,100);
    // i is the column (x), j the row (y): left/right bands, then top/bottom,
    // and a transparent (index 0) center.
    for (int i = 0; i < 100; i++) {
      for (int j = 0; j < 100; j++) {
        if(i < 25){
          image.setPixel(i,j,1);
        }
        else if(i > 75){
          image.setPixel(i,j,2);
        }
        else if(j < 25){
          image.setPixel(i,j,3);
        }
        else if(j > 75){
          image.setPixel(i,j,4);
        }
        else{
          image.setPixel(i,j,0);
        }
      }
    }
    i1 = image.getImage(palette1);
    i2 = image.getImage(palette2);
  }

  /** Draws both renderings, offset so they overlap. */
  public void paint(Graphics g){
    g.drawImage(i1,40,40,this);
    g.drawImage(i2,70,80,this);
  }

  /** Creates and shows the demo frame on the Swing Event Dispatch Thread. */
  public static void main(String[] args) 
  {
    // Swing components must be created and shown on the EDT, not the main thread.
    javax.swing.SwingUtilities.invokeLater(new Runnable() {
      public void run() {
        JFrame f = new JFrame();
        f.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        f.getContentPane().add(new Test());
        f.setSize(WIDTH,HEIGHT);
        f.setVisible(true);
      }
    });
  }
}
Hope this helps you.
Rafael.-