SWNet16 Neural Network

The idea is to fuse multiple small-width layers (i.e., width-16 layers) into one larger layer using the one-to-all connectivity of a fast transform (here, the Walsh-Hadamard transform), and then to stack those fused layers into a neural network.

There are a few technical issues, such as the need for spectral de-biasing at the input and output of the neural network, because fast transforms tend to pick out low-frequency information excessively. This is dealt with by applying sub-random sign flips at the input and at the output.

// Plot colors: c1 is used for the input trace, c2 for the recalled trace.
int c1, c2;

// Training set stored as interleaved rows: each even row is an input and the
// odd row that follows it is the matching target.
float[][] ex = new float[512][256];

// Buffer that receives the network output when an example is recalled.
float[] wk = new float[256];

// Not referenced in the visible part of the sketch.
int repeats = 25;

// The network under training.
// Constructor: SWNet16(vector size (16,32,64,...), depth, learning rate)
SWNet16 neuralNetwork;
/**
 * One-time sketch initialization: opens the window, builds the network,
 * chooses the two plot colors and fills the training set.
 */
void setup() {
  size(400, 400);
  textSize(16);
  c1 = color(255, 255, 0);
  c2 = color(0, 255, 255);
  neuralNetwork = new SWNet16(256, 2, 0.01);

  // Rows come in pairs: ex[pair] is the input (cosine) and ex[pair + 1] the
  // target (sine). The angular step grows with the pair index, so every
  // example has a slightly different frequency.
  for (int pair = 0; pair < 512; pair += 2) {
    float step = 0.03 + 0.0001 * pair;
    float[] inputRow = ex[pair];
    float[] targetRow = ex[pair + 1];
    for (int n = 0; n < 256; n++) {
      float angle = n * step;
      inputRow[n] = cos(angle);
      targetRow[n] = sin(angle);
    }
  }
}

/**
 * Per-frame callback: runs one training pass over every example, then plots
 * one example's input (color c1) together with the network's recall of it
 * (color c2) and prints a caption with the frame count.
 */
void draw() {
  background(0);

  // One training pass; the target row directly follows its input row.
  for (int row = 0; row < ex.length; row += 2) {
    float[] input = ex[row];
    float[] target = ex[row + 1];
    neuralNetwork.train(target, input);
  }

  // Step through the 256 examples, one per frame (input rows are the even ones).
  int current = frameCount & 255;
  float[] shown = ex[current * 2];
  neuralNetwork.recall(wk, shown);

  // Input trace.
  fill(c1);
  for (int px = 0; px < 256; px++) {
    int py = 100 - int(75 * shown[px]);
    set(px, py, c1);
  }

  // Recalled trace; the fill chosen here is also what the caption text uses.
  fill(c2);
  for (int px = 0; px < 256; px++) {
    int py = 100 - int(75 * wk[px]);
    set(px, py, c2);
  }

  text("Input cosine wave. Recall sine wave.", 5, 200);
  text("Iterations: " + frameCount, 5, 230);
}



/**
 * Neural network whose layers are made of independent width-16 sub-layers
 * that are mixed across the whole vector by a Walsh-Hadamard transform (WHT).
 *
 * Forward pass: sign flip (flipA) -> WHT -> [layer16 -> WHT] x depth -> sign flip (flipB).
 *
 * Every input element of a sub-layer owns two banks of 16 weights; the bank
 * is selected by the sign of that element (see layer16).
 *
 * The instance keeps scratch buffers (work, mini, sur) that are overwritten by
 * every call, so one instance must not be used from several threads at once.
 */
final class SWNet16 {
  final int vecLen;        // length of the input/output vectors
  final int depth;         // number of layer16 + WHT stages
  final float[][] wts;     // [layer][32 * vecLen]: two banks of 16 weights per input element
  final float[][] sur;     // [layer][vecLen]: input of each layer16, saved by train() for the backward pass
  final float[] work;      // scratch vector used by train() for the output and then the error
  final float[] mini;      // scratch copy of the current 16-element group
  final float[] flipA;     // +1/-1 pattern applied to the input
  final float[] flipB;     // +1/-1 pattern applied to the output
  float rate;              // learning rate
  private final float whtScale; // 1/sqrt(vecLen), keeps the WHT orthonormal

  /**
   * Creates a network with Gaussian random weights (scaled by 0.25).
   *
   * @param vecLen vector length; must be a power of two and at least 16
   * @param depth  number of layers; must not be negative
   * @param rate   learning rate used by train()
   * @throws IllegalArgumentException if vecLen or depth is out of range
   */
  public SWNet16(int vecLen, int depth, float rate) {
    // whtN() and the setFlips() mask need a power of two; layer16() walks the
    // vector in groups of 16. Anything else would give a wrong transform or
    // index past the end of the vector.
    if (vecLen < 16 || (vecLen & (vecLen - 1)) != 0) {
      throw new IllegalArgumentException("vecLen must be a power of two >= 16, got " + vecLen);
    }
    if (depth < 0) {
      throw new IllegalArgumentException("depth must not be negative, got " + depth);
    }
    this.vecLen = vecLen;
    this.depth = depth;
    this.rate = rate;
    whtScale = 1f / (float) Math.sqrt(vecLen);
    wts = new float[depth][32 * vecLen];
    sur = new float[depth][vecLen];
    work = new float[vecLen];
    mini = new float[16];
    flipA = new float[vecLen];
    flipB = new float[vecLen];
    for (int i = 0; i < depth; i++) {
      for (int j = 0; j < wts[i].length; j++) {
        wts[i][j] = randomGaussian() * 0.25f;
      }
    }
    setFlips(7, 277);
  }

  /**
   * Computes the network output for input and writes it into result.
   * The two arrays may be the same array; both need at least vecLen elements.
   *
   * @param result receives the output
   * @param input  the input vector (not modified unless it is also result)
   */
  public void recall(float[] result, float[] input) {
    forward(result, input, false);
  }

  /**
   * Performs one training step: a forward pass followed by a backward pass
   * that adjusts the weights of every layer towards target.
   *
   * @param target desired output for input
   * @param input  the input vector
   */
  public void train(float[] target, float[] input) {
    forward(work, input, true);
    // Output error, pre-scaled by the learning rate and taken back through the
    // output sign flip (the flip is its own inverse).
    for (int i = 0; i < vecLen; i++) {
      work[i] = flipB[i] * (target[i] - work[i]) * rate;
    }
    // Walk the layers in reverse: undo the WHT (whtN applied twice returns the
    // original vector), then update the layer and push the error through it.
    for (int i = depth - 1; i >= 0; i--) {
      whtN(work);
      layer16Back(sur[i], wts[i], work);
    }
  }

  /** Forward pass; when keepInputs is set, the input of every layer16 is saved in sur. */
  private void forward(float[] result, float[] input, boolean keepInputs) {
    for (int i = 0; i < vecLen; i++) {
      result[i] = input[i] * flipA[i];
    }
    whtN(result);
    for (int layer = 0; layer < depth; layer++) {
      if (keepInputs) {
        System.arraycopy(result, 0, sur[layer], 0, vecLen);
      }
      layer16(result, wts[layer]);
      whtN(result);
    }
    for (int i = 0; i < vecLen; i++) {
      result[i] *= flipB[i];
    }
  }

  /**
   * Applies one layer in place. The vector is processed in groups of 16; each
   * input element adds itself, times a row of 16 weights, into the 16 outputs
   * of its group. A negative element uses the second weight bank of its row.
   *
   * @param x vector to transform in place
   * @param w weights of this layer (32 per element of x)
   */
  public void layer16(float[] x, float[] w) {
    for (int i = 0; i < vecLen; i += 16) {
      // Save the group and clear it so it can accumulate the outputs.
      for (int j = 0; j < 16; j++) {
        mini[j] = x[i + j];
        x[i + j] = 0f;
      }
      for (int j = 0; j < 16; j++) {
        float v = mini[j];
        int idx = 32 * (i + j);
        if (v < 0f) idx += 16; // negative input: second bank
        for (int k = 0; k < 16; k++) {
          x[i + k] += v * w[idx + k];
        }
      }
    }
  }

  /**
   * Backward counterpart of layer16: updates the weights that were active for
   * the saved input x and replaces e by the error with respect to that input.
   *
   * @param x the input that layer16 saw in the forward pass
   * @param w weights of this layer; updated in place
   * @param e on entry the error at the layer output (already scaled by the
   *          learning rate), on exit the error at the layer input
   */
  public void layer16Back(float[] x, float[] w, float[] e) {
    for (int i = 0; i < vecLen; i += 16) {
      for (int j = 0; j < 16; j++) {
        mini[j] = e[i + j];
      }
      for (int j = 0; j < 16; j++) {
        float v = x[i + j];
        int idx = 32 * (i + j);
        if (v < 0f) idx += 16; // same bank that the forward pass selected
        float es = 0f;
        for (int k = 0; k < 16; k++) {
          // Propagate the error through the weight the forward pass actually
          // used, i.e. read it before applying the update.
          float wOld = w[idx + k];
          es += wOld * mini[k];
          w[idx + k] = wOld + v * mini[k];
        }
        e[i + j] = es;
      }
    }
  }

  /**
   * Fills flipA and flipB with +1/-1 patterns. Each pattern follows the
   * additive sequence r = (r + q) mod vecLen and takes the bit parity of r:
   * odd parity gives +1, even parity gives -1.
   *
   * @param q1 increment of the sequence behind flipA
   * @param q2 increment of the sequence behind flipB
   */
  public void setFlips(int q1, int q2) {
    int r1 = 0, r2 = 0, mask = vecLen - 1;
    for (int i = 0; i < vecLen; i++) {
      r1 = (r1 + q1) & mask;
      r2 = (r2 + q2) & mask;
      flipA[i] = ((Integer.bitCount(r1) & 1) << 1) - 1;
      flipB[i] = ((Integer.bitCount(r2) & 1) << 1) - 1;
    }
  }

  /**
   * Multiplies every weight of every layer by sc.
   *
   * @param sc scale factor
   */
  public void wtScale(float sc) {
    for (int i = 0; i < depth; i++) {
      for (int j = 0; j < wts[i].length; j++) {
        wts[i][j] *= sc;
      }
    }
  }

  /**
   * In-place Walsh-Hadamard transform of the first vecLen elements of x,
   * scaled by 1/sqrt(vecLen).
   *
   * @param x vector to transform in place
   */
  public void whtN(float[] x) {
    // Butterfly passes with half-span hs = 1, 2, 4, ...
    for (int hs = 1; hs < vecLen; hs += hs) {
      for (int start = 0; start < vecLen; start += 2 * hs) {
        for (int i = start; i < start + hs; i++) {
          float a = x[i];
          float b = x[i + hs];
          x[i] = a + b;
          x[i + hs] = a - b;
        }
      }
    }
    for (int i = 0; i < vecLen; i++) {
      x[i] *= whtScale;
    }
  }

  /** @return the current learning rate */
  public float getRate() {
    return rate;
  }

  /** @param r the new learning rate */
  public void setRate(float r) {
    rate = r;
  }
}