Flaw in current neural networks

I wanted to post this on a coding train episode on youtube about neural networks but I’ve been banned from commenting (again!) —

There is a basic problem with conventional neural networks. There are too many weighted sums operating off a small set of nonlinearized values. The outputs of the weighed sums then are correlated/entangled with each other. Aside from anything else this is an inefficient use of weight parameters:
https://discourse.numenta.org/t/non-linearity-sharing-in-deep-neural-networks-a-flaw/6033

There is a solution using random projections, or actually other projections that are faster to calculate as long as they have some specific properties.

Also the linear associative memory (AM) aspect of the weighted sum is poorly know, people should remember it. Linear AM does come with a lot of provisos, however in conjunction with nonlinear functions it becomes a more general type of associative memory. I sort of half explained it here:
https://discourse.numenta.org/t/towards-demystifying-over-parameterization-in-deep-learning/5985

2 Likes

this sounds very interesting,
could you show this as a processing example project
or even as a tutorial?

I’m being a LA at the moment and not writing much code. Just ruminating a bit.
I did show this related associative memory thing before:
www.gamespace.eu5.org/associativememory/index.html
You can just download the page to get the JS/P5 code.
The code is on a free webserver and I think that company sometimes plays games with where things link to but the link usually works.

If I understand, you are saying:

  1. “conventional” neural networks are inefficient.
  2. you can solve this problem (inefficiency) with “other projections that are faster to calculate as long as they have some specific properties.”

Two thoughts:

Coding Train is for (primarily) introductory learners to learn coding. That is the audience. Their goal is not to learn the most efficient techniques, but to learn basic / introductory techniques – if efficiency was the primary goal then it wouldn’t make sense to use Java / JavaScript over e.g. C++.

It is also a tutorial how-to series – if you want to make a suggested improvement to the material, that suggestions should be in the form of alternate instructions that someone could actually follow, not a suggestion that beginners do their own research in techniques that are “poorly known” with “a lot of provisos.”

An associative memory tutorial for Processing / P5 sounds interesting – if one doesn’t exist, perhaps you should develop one, or partner with someone to do that! I recall you posting the Auto-Associative demo earlier, and I found it interesting (although a bit confusing).

3 Likes

I use processing, I can post here regardless I would guess. I may use processing java again as well as processing JS. I had some kind of issue with multithreaded image loading with the Java version, maybe I forgot to use some preload function.
Point 2, I said in the link you can use random projections which take O(nln(n)). A bit slow. After thinking about it there may be faster projections you could use.

In terms of a tutorial it looks like presenting one on the weighted sum would be worthwhile. There are quite a few aspects.

  1. The dot product and angular distance
  2. Linear associative memory (AM)
  3. Under capacity linear AM giving error correction by repetition
  4. Over capacity giving recall+Gaussian noise, but still close in angular distance
  5. Conversion to more general associative memory using nonlinear functions
  6. The linear classifier behavior.
  7. Non-linear classifier with prior application of nonlinear functions.
  8. The central limit theorem and the weighed sum
  9. Adaptive filter viewpoint of the weighted sum
  10. Correlation learning viewpoint
  11. Lots of related math like simulations equations up to the Moore–Penrose inverse.
    That would actually be the basics of artificial neural networks. Which ought to be extremely well known at this stage.
    More information: https://github.com/FALCONN-LIB/FFHT/issues/26
1 Like

Version in Processing:

// Code in Processing www.processing.org
final int EDGE=32;    // image edge size
final int DENSITY=2;  // neural network density
final int DEPTH=5;    // neural network depth
final int MUTATIONS=25;   // number of items to mutate during optimization
final float PRECISION=25f;// mutation strength  
final int THREADS=4;      // number of training threads
final int VECLEN=EDGE*EDGE*4;
volatile boolean shouldRun;  // All threads see this in a correct mutual way.
boolean teach;
float[] work=new float[VECLEN];
float[][] imgVectors;
java.util.Vector<Thread> threadList=new java.util.Vector<Thread>(); // active threads
XHNet10 parent=new XHNet10(VECLEN, DENSITY, DEPTH);  // neural network
float parentCost=Float.POSITIVE_INFINITY;

void setup() {
  size(300, 300);
  frameRate(3);
  File[] list=listFiles("data/");
  imgVectors=new float[list.length][VECLEN];
  float rsc=1f/127.5f;
  for (int i=0; i<list.length; i++) {
    PImage img=loadImage(list[i].toString());
    int pos=0;
    for (int y=0; y<EDGE; y++) {
      for (int x=0; x<EDGE; x++) {
        int c=img.get(x, y);
        float r=rsc*((c & 255)-127.5f);
        float g=rsc*(((c>>8) & 255)-127.5f);
        float b=rsc*(((c>>16) & 255)-127.5f);
        float av=0.3333333f*(r+g+b);
        imgVectors[i][pos++]=r;
        imgVectors[i][pos++]=g;
        imgVectors[i][pos++]=b;
        imgVectors[i][pos++]=av;
      }
    }
  }
}  

void displayVector(float[] vec) {
  int pos=0;
  for (int y=0; y<EDGE; y++) {
    for (int x=0; x<EDGE; x++) {
      int r=constrain(round(vec[pos++]*127.5f+127.5f), 0, 255);
      int g=constrain(round(vec[pos++]*127.5f+127.5f), 0, 255);
      int b=constrain(round(vec[pos++]*127.5f+127.5f), 0, 255);
      pos++;
      int c=r | (g<<8) | (b<<16) | 0xff000000;
      for (int i=0; i<8; i++) {
        for (int j=0; j<8; j++) {
          set(j+x*8+22, i+y*8+44, c);
        }
      }
    }
  }
}  

synchronized void updateParent(XHNet10 child, float childCost) {
  if (childCost<parentCost) {
    parentCost=childCost;  // should really wrap with synchronize, because another thread looks at it
    arrayCopy(child.weights, parent.weights);// but no real harm can happen
  } else {
    arrayCopy(parent.weights, child.weights);
  }
  for (int i=0; i<MUTATIONS; i++) {
    int r=int(random(child.weights.length));
    float v=child.weights[r];
    float m=2f*exp(-random(PRECISION));
    if (random(-1f, 1f)<0f) m=-m;
    m+=v;
    if (m<-1f) m=v;
    if (m>1f) m=v;
    child.weights[r]=m;
  }
}  

void threadTrain() {
  threadList.addElement(Thread.currentThread());
  XHNet10 child=new XHNet10(VECLEN, DENSITY, DEPTH);
  float[] res=new float[VECLEN];
  while (shouldRun) {
    float childCost=0f;
    for (int i=0; i<imgVectors.length; i++) {
      child.recall(res, imgVectors[i]);
      for (int j=0; j<VECLEN; j++) {
        float d=res[j]-imgVectors[i][j];
        childCost+=d*d;
      }
    }
    updateParent(child, childCost);
  }
}  

void threadRecall() {
  threadList.addElement(Thread.currentThread());
  int count=0;
  while (shouldRun) {
    parent.recall(work, imgVectors[count++]);
    if (count==imgVectors.length) count=0;
    try { 
      Thread.sleep(2000);
    }
    catch(Exception rte) {
    }
  }
}  

void threadNoise() {
  threadList.addElement(Thread.currentThread());
  while (shouldRun) {
    for (int i=0; i<VECLEN; i++) {
      work[i]=random(-1f, 1f);
    }
    parent.recall(work, work);
    try { 
      Thread.sleep(2000);
    }
    catch(Exception rne) {
    }
  }
}  

void keyPressed() {
  if ((key=='s' || key=='S') && shouldRun) {
    shouldRun=false;
    teach=false;
    try {
      for (Thread t : threadList) t.join();
    }
    catch(Exception te) {
    }
    threadList.clear();
  }  
  if ((key=='t' || key=='T') && !shouldRun) {
    shouldRun=true;
    teach=true;
    for (int i=0; i<THREADS; i++) thread("threadTrain");
  } 
  if ((key=='r' || key=='R') && !shouldRun) {
    shouldRun=true;
    thread("threadRecall");
  } 
  if ((key=='n' || key=='N') && !shouldRun) {
    shouldRun=true;
    thread("threadNoise");
  }
}

void draw() {
  background(0);
  text("Train 'T',  Recall 'R',  Recall Noise 'N', Stop 'S'", 2, 20);
  text("Cost: "+parentCost, 2, 40);
  if (shouldRun && !teach) displayVector(work);
}

class XHNet10 {
  int vecLen;
  int density;
  int depth;
  float layerScale;
  float[] weights;
  float[] workA;
  float[] workB;

  // vecLen must be an int power of 2 (2,4,8,16,32,64...)
  XHNet10(int vecLen, int density, int depth) {
    this.vecLen=vecLen;
    this.density=density;
    this.depth=depth;
    layerScale=2f/sqrt(vecLen*density);
    weights=new float[3*vecLen*density*depth];
    workA=new float[vecLen];
    workB=new float[vecLen];  
    for (int i=0; i<weights.length; i++) {
      weights[i]=random(-1f, 1f);
    }
  }

  void recall(float[] result, float[] input) {
    adjust(workA, input, 4f*layerScale/vecLen);  //Normalize
    signFlip(workA, 123456);  // Hash based random sign flip
    whtRaw(workA);           // +WHT = Random Projection
    int wtIndex=0;
    int i=0;       // depth counter
    while (true) { // depth loop
      zero(result);
      for (int j=0; j<density; j++) {  // density loop

        for (int k=0; k<vecLen; k++) {     // premultiply by weights
          workB[k]=workA[k]*weights[wtIndex++];
        }  
        whtRaw(workB); // premultiply + Walsh Hadamard transform = Spinner projection
        for (int k=0; k<vecLen; k++) { // switch slope at zero activation function
          if (workB[k]<0f) {                          
            result[k]+=workB[k]*weights[wtIndex];    // Slope A
          } else {
            result[k]+=workB[k]*weights[wtIndex+1];  // Or Slope B
          }  
          wtIndex+=2;
        }
      }  // density loop end
      i++;
      if (i==depth) break;  // depth loop end
      scale(workA, result, layerScale);
    }  // depth loop continue
    signFlip(result, 654321);
    whtRaw(result); // Final random projection helps if density is low
  }

  // Fast Walsh Hadamard transform with no scaling    
  void whtRaw(float[] vec) {
    int i, j, hs = 1, n = vec.length;
    while (hs < n) {
      i = 0;
      while (i < n) {
        j = i + hs;
        while (i < j) {
          float a = vec[i];
          float b = vec[i + hs];
          vec[i] = a + b;
          vec[i + hs] = a - b;
          i += 1;
        }
        i += hs;
      }
      hs += hs;
    }
  }

  // recomputable random sign flip of the elements of vec 
  void signFlip(float[] vec, int h) {
    for (int i=0; i<vec.length; i++) {
      h*=0x9E3779B9;
      h+=0x6A09E667;
      // Faster than -  if(h<0) vec[i]=-vec[i];
      vec[i]=Float.intBitsToFloat((h&0x80000000)^Float.floatToRawIntBits(vec[i]));
    }
  }

  void adjust(float[] x, float[] y, float scale) {
    float sum = 0f;
    int n=x.length;
    for (int i = 0; i < n; i++) {
      sum += y[i] * y[i];
    }
    float adj = scale/ (float) Math.sqrt((sum/n) + 1e-20f);
    scale(x, y, adj);
  }  

  void scale(float[] res, float[] x, float scale) {
    for (int i=0, n=res.length; i<n; i++) {
      res[i]=x[i]*scale;
    }
  }

  void zero(float[] x) {
    for (int i=0, n=x.length; i<n; i++) x[i]=0f;
  }
}  

There is a test data folder here:
https://github.com/S6Regen/RP_AUTO_RESNET

2 Likes

I’m looking to try this code, but File[] list = listFiles(“data/”) isn’t working for me. I have the jpgs in my data folder. I imported java.io.File but the listFiles() function isn’t available or something.

The method listFiles() is in the examples but not in the reference as far as I can see. Also dataPath("").
You could try:

void setup() {
  size(300, 300);
  frameRate(3);
  File[] list=null;
  try {
    File f=new File(dataPath(""));
    list=f.listFiles();
  }
  catch(Exception fe) {
  };
  imgVectors=new float[list.length][VECLEN];
  float rsc=1f/127.5f;
  for (int i=0; i<list.length; i++) {
    PImage img=loadImage(list[i].toString());
    int pos=0;
    for (int y=0; y<EDGE; y++) {
      for (int x=0; x<EDGE; x++) {
        int c=img.get(x, y);
        float r=rsc*((c & 255)-127.5f);
        float g=rsc*(((c>>8) & 255)-127.5f);
        float b=rsc*(((c>>16) & 255)-127.5f);
        float av=0.3333333f*(r+g+b);
        imgVectors[i][pos++]=r;
        imgVectors[i][pos++]=g;
        imgVectors[i][pos++]=b;
        imgVectors[i][pos++]=av;
      }
    }
  }
}  

Or maybe you define the data folder path some different way for Windows, which I wouldn’t know about.

Or just bluntly:

void setup() {
  size(300, 300);
  frameRate(3);
  imgVectors=new float[10][VECLEN];
  float rsc=1f/127.5f;
  for (int i=0; i<10; i++) {
    PImage img=loadImage(Integer.toString(i+1)+".jpg");
    int pos=0;
    for (int y=0; y<EDGE; y++) {
      for (int x=0; x<EDGE; x++) {
        int c=img.get(x, y);
        float r=rsc*((c & 255)-127.5f);
        float g=rsc*(((c>>8) & 255)-127.5f);
        float b=rsc*(((c>>16) & 255)-127.5f);
        float av=0.3333333f*(r+g+b);
        imgVectors[i][pos++]=r;
        imgVectors[i][pos++]=g;
        imgVectors[i][pos++]=b;
        imgVectors[i][pos++]=av;
      }
    }
  }
}  
1 Like

Blunt works. I just tried the code. I let it train down to cost: 300 and then wanted to see what it actually learned. The anime chicas came into view one by one. The images were as if a small .jpg had been zoomed until you could see individual pixels, but they were very recognizable. I then ran it again down to cost: 18,000 and hit recall. This time the images were much poorer quality. I’m not entirely sure what is going on but I get it that longer training gets better results.

Thanks for posting it.