/*****************************************************************************
 *                                                                           *
 * FOX Controller Library, Version 1.3.3.                                    *
 *                                                                           *
 * Copyright (C) 1998,1999,2000 Russell Smith (rl.smith@auckland.ac.nz)      *
 *                                                                           *
 * The FOX Controller Library is free software; you can redistribute it      *
 * and/or modify it under the terms of the GNU Library General Public        *
 * License as published by the Free Software Foundation; either version      *
 * 2 of the License, or (at your option) any later version.                  *
 *                                                                           *
 * The FOX Controller Library is distributed in the hope that it will be     *
 * useful, but WITHOUT ANY WARRANTY; without even the implied warranty of    *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU         *
 * Library General Public License for more details.                          *
 *                                                                           *
 * You should have received a copy of the GNU Library General Public         *
 * License along with the Fox Controller Library; see the file COPYING.LIB.  *
 * If not, write to the Free Software Foundation, Inc., 59 Temple Place -    *
 * Suite 330, Boston, MA 02111-1307, USA.                                    *
 *                                                                           *
 *****************************************************************************/

/*

FOXN: CMAC with vector eligibility
----------------------------------

We want to avoid having two or more weight indexes point to the same weight
for any input, so we don't have to do tricky things with the eligibility
computations to take the multiple references into account.

In the "normal" cmac the weight index is a hash function of the virtual
address (made up of the input receptive field numbers) and the AU
number. The weight index was between 0..wsize-1. We could instead compute the
weight index as a hash function of the virtual address alone, hashed into the
range 0..wsize-1, PLUS wsize times the AU number, so the weight index is
between 0..(wsize*C)-1. The weight table would be divided into `C' 
wsize-blocks. This has the advantage that the weight indexes would always be
distinct.

BUT this scheme is not quite perfect because the basis functions for each AU
will look the same, so a hash collision (i.e. a basis function localized in
more than one useful area) will not "spread itself around" but will instead
be concentrated in the AU basis functions. The solution is to make the AU
number a parameter of the hash function also.

However, as C increased the total number of weights would increase, but
the weight requirement actually decreases. Thus we have the user specify the
desired number of weights per output, wsize is computed so that wsize*C is
close to this number.

The "slow" method of updating eligibilities is to perform the filter
calculations for each weight. For accurate comparisons with the fast method
we need to get the effect of "retiring" weights, so we use the aux[] index
value to store the time at which the weight was last accessed.

The "inactive weight" queue is split into two parts, `buffer1' and `buffer2'.
Buffer 1 stores C weight indexes at each buffer position. Each index is the
index of an inactive weight, or -1 if none.
Buffer 2 stores C eligibility vectors at each buffer position, which
correspond to the non -1 weight indexes. Each vector has size `ny', so C*ny
numbers are stored at each buffer position. The vector is ignored if the
corresponding weight index is -1.

NOTE that the eligibility vectors stored in buffer2 should be unique for each
output (the eligibility filter driving force from matrix B is different for
each output), but we won't worry about that yet as we are currently only
supporting one output.

A weight index can only be in the inactive queue once. The eligibility
vector of an inactive weight records its eligibility at the time it was last
active. If a weight is _not_ in the inactive queue then its eligibility is
assumed to be 0.

We store some auxiliary information about each weight index (NOT for each
weight, i.e. for more outputs there is still the same amount of auxiliary
information). The auxiliary information is an integer `i' : if the weight is
in the inactive queue, this is its position, otherwise the weight's stored
value is correct and `i' is -1.

See cmac.cc for additional comments.

*/

#include <string.h>
#include "foxstuff.h"
#include "fox-n.h"

//***************************************************************************
// see the .cc file for a description of what this does

// Forward declaration only; the function is defined in another file.
// In this file it is called once, by the FoxN constructor, as
// GetCMACOverlayDisplacementVector (nin, C); a null result makes the
// constructor return early, leaving the object invalid.
char * GetCMACOverlayDisplacementVector (int n, int p);

//***************************************************************************
// Here is a really hashy hash function from "Numerical Recipes in C",
// 2nd Ed, page 300. It does a pseudo-DES hashing of the given 64 bit word
// (lword,irword). Both 32-bit arguments are returned hashed on all bits.
// This does a much better job of nonlinear pseudo-random bit mixing than
// some other hash functions (like mod-prime or linear congruential) but
// at the expense of speed. It's still pretty quick, though.

#define PDES_ITERS 4

// Hash the 64-bit word (lword,irword) in place. On return both arguments
// hold 32-bit values (0..0xffffffff).
//
// The algorithm is defined on 32-bit words, but `unsigned long' may be wider
// (it is 64 bits on LP64 platforms). Every intermediate result is therefore
// reduced modulo 2^32, so the function produces the same values whatever the
// width of `unsigned long'. Any bits above the low 32 of the inputs are
// ignored. Where `unsigned long' is exactly 32 bits the masks are no-ops.
static void pdes_hash (unsigned long &lword, unsigned long &irword)
{
  const unsigned long mask32 = 0xffffffffUL;
  static const unsigned long c1 [PDES_ITERS] = {
    0xbaa96887UL,0x1e17d32cUL,0x03bcdc3cUL,0x0f33d1b2UL};
  static const unsigned long c2 [PDES_ITERS] = {
    0x4b0f3b58UL,0xe874f0c3UL,0x6955c5a6UL,0x55a7ca46UL};

  lword &= mask32;
  irword &= mask32;

  // Perform PDES_ITERS iterations of DES logic, using a simpler
  // (non-cryptographic) nonlinear function instead of DES's.
  for (int i=0; i<PDES_ITERS; i++) {
    const unsigned long iswap = irword;
    unsigned long ia = iswap ^ c1[i];
    const unsigned long itmpl = ia & 0xffff;    // low 16 bits
    const unsigned long itmph = ia >> 16;       // high 16 bits
    // `~' sets every bit above bit 31 on a wide word, hence the mask
    const unsigned long ib = (itmpl*itmpl + ~(itmph*itmph)) & mask32;
    ia = (ib >> 16) | ((ib & 0xffff) << 16);    // swap 16-bit halves
    irword = (lword ^ ((ia ^ c2[i]) + itmpl*itmph)) & mask32;
    lword = iswap;
  }
}


// Test pdes_hash to make sure the integer arithmetic is being done ok.
// If anything is wrong an error message is printed and the program is halted.
static void test_pdes_hash()
{
  // No check on sizeof(unsigned long) is made here: the language guarantees
  // that `unsigned long' can hold any 32-bit value, and a wider type is not
  // in itself an error. The known-answer test below is the real gate - any
  // word-width or arithmetic problem in pdes_hash shows up as a wrong result.
  unsigned long a = 99;
  unsigned long b = 99;
  pdes_hash (a,b);
  if ((a != 0xd7f376f0UL) || (b != 0x59ba89ebUL))
    ZFatalError ("Internal error: pdes_hash integer arithmetic incorrect.");
}


// Automatic pdes_hash tester: the constructor of this file-scope object calls
// test_pdes_hash(), so the check runs when the object is initialized,
// without any explicit call from user code.
static struct automatic_pdes_hash_tester {
  automatic_pdes_hash_tester()
  {
    test_pdes_hash();
  }
} test_pdes_hash_automatically;

//***************************************************************************
// Check that we can set a `cftype' to zero by setting all its bits to zero,
// because we want to use memset() to zero out some large arrays.

static void check_zero_cftype()
{
  cftype a;
  a = 3.14159265358979323846;
  memset (&a,0,sizeof(a));
  if (a != 0.0)
    ZFatalError ("Internal error: floating point format not supported.");
}

// Automatic check_zero_cftype caller: the constructor of this file-scope
// object calls check_zero_cftype() when the object is initialized.
static struct automatic_check_zero_cftype {
  automatic_check_zero_cftype()
  {
    check_zero_cftype();
  }
} call_check_zero_cftype_automatically;

//***************************************************************************
// FoxN

// Construct a FoxN.
//
//   num_inputs   number of inputs, must be in 1..CMAC_MAXIN.
//   input_info   array of `num_inputs' entries giving the resolution (>= 2)
//                and the min / max (max must exceed min) of each input.
//   num_weights  desired number of weights, must be >= C. The table actually
//                allocated holds (num_weights / C) * C weights.
//   iC           number of association units, must be in 1..CMAC_MAXC.
//   _ny          size of the eligibility vectors, must be in 1..MAX_NY.
//   _matA        ny*ny values, copied into matA.
//   _matB        ny values, copied into matB.
//   _matC        ny values, copied into matC.
//   _history     number of buffer positions in the inactive-weight queue,
//                must be > 0.
//
// `valid' is set to 1 only when construction runs to completion; it is 0 if
// the constructor returns early. NOTE(review): checkthat() / checkmemptr()
// are assumed to return from the constructor when their condition fails -
// confirm against their definitions.
FoxN::FoxN (int num_inputs, CmacInputInfo *input_info,
  int num_weights, int iC, int _ny, cftype *_matA, cftype *_matB,
  cftype *_matC, int _history)
{
  int i;
  unsigned long a;

  valid = 0;

  // set all pointers to NULL so that the destructor can safely destruct a
  // partially constructed object
  weights = NULL;
  windex = NULL;
  buffer1 = NULL;
  buffer2 = NULL;
  atoi = NULL;
  aux = NULL;
  trace = NULL;
#ifdef SLOW_METHOD
  slow_eli = NULL;
#endif

  nin = num_inputs;
  C = iC;
  checkthat (nin > 0 && nin <= CMAC_MAXIN,"bad number of inputs (%d)",nin);
  checkthat (C > 0 && C <= CMAC_MAXC,"bad number of association units (%d)",C);

  // setup input info array
  for (i=0; i<nin; i++) {
    iinfo[i].res = input_info[i].res;
    checkthat (iinfo[i].res >= 2,"resolution for input %d too small",i);
    iinfo[i].min = input_info[i].min;
    iinfo[i].max = input_info[i].max;
    // Reject an empty or inverted range before dividing by its width: with
    // max == min the quotient below would be +infinity, which would slip
    // through the `q > 0' test.
    checkthat (iinfo[i].max > iinfo[i].min,"bad max / min for input %d",i);
    iinfo[i].q = cftype(iinfo[i].res) / (iinfo[i].max - iinfo[i].min);
    checkthat (iinfo[i].q > 0,"bad max / min for input %d",i);
    iinfo[i].parts = (iinfo[i].res-2)/C + 2;
  }

  // Check that the receptive field numbers for all inputs can be packed
  // into a 32-bit word - multiply them together and check for overflow.
  a = iinfo[0].parts;
  for (i=1; i<nin; i++) {
    // check for potential overflow
    checkthat (a < (0xffffffffL / iinfo[i].parts),"too much input resolution");
    a *= iinfo[i].parts;
  }

  // setup weight parameters
  checkthat (num_weights >= C,"number of weights must be >= C");
  wsize = num_weights / C;

  // Setup weights.
  checkmemptr (weights = new cftype [wsize * C]);
  memset (weights, 0, wsize*C * sizeof(cftype));

  // setup other arrays
  checkmemptr (windex = new int [C]);
  odv = GetCMACOverlayDisplacementVector (nin, C);
  if (!odv) return;
  checkmemptr (aux = new int [wsize*C]);

  output = 0.0;
  old_output = 0.0;

  // setup eligibility stuff

  ny = _ny;
  checkthat (ny >= 1 && ny <= MAX_NY,"ny must be in the range 1..%d",MAX_NY);
  memcpy (matA,_matA,ny*ny * sizeof(cftype));
  memcpy (matB,_matB,ny*1 * sizeof(cftype));
  memcpy (matC,_matC,ny * sizeof(cftype));

  elimode = 1;
  history = _history;
  checkthat (history > 0,"history must be > 0");
  bufsize1 = history * C;
  bufsize2 = history * C * ny;
  checkmemptr (buffer1 = new int [bufsize1]);
  checkmemptr (buffer2 = new cftype [bufsize2]);
  checkmemptr (atoi = new GetAtoi (ny,matA,history + 2));
  if (!atoi->Valid()) return;
  checkmemptr (trace = new TraceN (ny,matA,matC,history + 2,atoi));
  if (!trace->Valid()) return;

  // set all buffer indexes to -1, else Reset() will try to correct the weights
  memset (buffer1,~0,sizeof(int) * bufsize1);

#ifdef SLOW_METHOD
  checkmemptr (slow_eli = new cftype [wsize * C * ny]);
#endif

  // Reset() advances `time' and `head' before re-initializing them, so give
  // them defined values first (the same values Reset() leaves behind).
  time = -1;
  head = -C;

  Reset();
  valid = 1;
}


// Release everything the constructor allocated. Pointers that were never
// allocated are still NULL (the constructor clears them first), so this is
// safe for a partially constructed object.
FoxN::~FoxN()
{
  // plain arrays
  delete[] weights;
  delete[] windex;
  delete[] aux;
  delete[] buffer1;
  delete[] buffer2;
#ifdef SLOW_METHOD
  delete[] slow_eli;
#endif

  // helper objects (atoi is released before trace, as before)
  delete atoi;
  delete trace;
}


// Switch the eligibility mode: any nonzero `_elimode' selects mode 1, zero
// selects mode 0. Must not be called while a Train() call is pending
// (elimode == 2). Leaving mode 1 for mode 0 calls Reset() first.
void FoxN::SetEligibilityMode (int _elimode)
{
  myassert (elimode !=2,"can't change eligibility mode, Train() call pending");
  int requested = (_elimode != 0);
  if (requested != elimode) {
    if (elimode == 1 && requested == 0) Reset();
    elimode = requested;
  }
}


// Map an input vector to an output. This is MapGate() with the gate
// argument fixed at 1.
void FoxN::Map (cftype *input)
{
  MapGate (input,1);
}


// Map an input vector to an output.
//
// The input is quantized and the C active weight indexes are computed into
// `windex'. If eligibility mode is on (elimode != 0) the eligibilities are
// then updated and elimode becomes 2, meaning that a Train() call is pending.
// Finally `old_output' takes the previous output and `output' becomes the
// sum of the C active weights.
//
//   input  array of `nin' input values.
//   gate   if nonzero, the driving term `matB' is added to the eligibility
//          of every active weight. If zero the active weights only keep
//          their decayed eligibility. Both the fast and the slow method
//          honour the gate, so they stay comparable.
void FoxN::MapGate (cftype *input, int gate)
{
  int i,j;

  Quantize (input);
  GetAssociations();

#ifdef FAST_METHOD
  // If in mapping mode, all weights will be correct.
  if (elimode) {
    myassert (elimode == 1,"called Map() but call to Train() pending");
    elimode = 2;

    // advance buffer head and retire old weights
    time++;
    head += C;
    if (head >= bufsize1) head = 0;
    for (i=0; i<C; i++) {
      j = buffer1[head+i];
      if (j != -1) GetCorrectWeight (j,1);      // retire weight j
    }

    // Get correct values and eligibilities for indexed weights, then add
    // these weights to the buffer. All indexes in `windex' are distinct so we
    // don't accidentally do this processing more than once per weight index.
    for (i=0; i<C; i++) {
      const cftype *eli = GetCorrectWeight (windex[i],0);

      // driving part of eligibility filter - this is just `B'
      cftype *buf = buffer2 + (head+i)*ny;
      if (gate) {
        for (j=0; j<ny; j++) buf[j] = eli[j] + matB[j];
      }
      else {
        for (j=0; j<ny; j++) buf[j] = eli[j];
      }

      buffer1[head+i] = windex[i];
      aux[windex[i]] = head+i;
    }
  }
#endif

#ifdef SLOW_METHOD              // slow eligibility method
  if (elimode) {
    myassert (elimode == 1,"called Map() but call to Train() pending");
    elimode = 2;
    cftype *eptr;
    time++;

    // "retire" old weights, to get same outputs as fast method.
    for (i=0; i < (wsize*C); i++) {
      if ( (aux[i] >= 0) && ((time - aux [i]) >= history) ) {
        eptr = slow_eli + i*ny;
        for (j=0; j<ny; j++) eptr[j] = 0.0;
      }
    }

    // decay all eligibilities
    cftype tmp[MAX_NY];
    eptr = slow_eli;
    for (i=0; i < (wsize*C); i++) {
      MatrixMultiply2 (ny,tmp,matA,eptr);
      memcpy (eptr,tmp,ny * sizeof(cftype));
      eptr += ny;
    }

    // Driving part of eligibility filter for active weights. As in the fast
    // method this is only applied when the gate is open.
    if (gate) {
      for (i=0; i<C; i++) {
        eptr = slow_eli + windex[i]*ny;
        for (j=0; j<ny; j++) eptr[j] += matB[j];
      }
    }

    // set time of last access for active weights
    for (i=0; i<C; i++) aux [windex[i]] = time;
  }
#endif

  // remember previous output
  old_output = output;

  // calculate outputs
  output = 0.0;
  for (i=0; i<C; i++) output += weights [windex[i]];
}


// Apply a training error. Must follow a Map() / MapGate() call made in
// eligibility mode (elimode == 2); elimode goes back to 1 afterwards.
// The error is divided by C before it is used.
void FoxN::Train (cftype error)
{
  myassert (elimode == 2,"out of order call to Train()");
  elimode = 1;

#ifdef FAST_METHOD
  // Fast method: just record the scaled error and hand it to the trace
  // object.
  cftype scaled = error / (cftype (C));
  lasterror = scaled;
  trace->Next (scaled);
#endif

#ifdef SLOW_METHOD
  // Slow method: update every weight by scaled error times (matC . eli),
  // where eli is that weight's eligibility vector in slow_eli.
  cftype scaled = error / (cftype (C));
  for (int w=0; w < (wsize*C); w++) {
    cftype *vec = slow_eli + w*ny;
    cftype dot = 0.0;
    for (int j=0; j<ny; j++) dot += matC[j] * vec[j];
    weights[w] += scaled * dot;
  }
#endif
}


// Add -lr * output to each of the C weights indexed by `windex'.
// Only legal while a Train() call is pending (elimode == 2).
void FoxN::LimitOutput (cftype lr)
{
  myassert (elimode == 2,"out of order call to LimitOutput()");
  cftype delta = - lr * output;
  for (int au=0; au<C; au++) {
    const int w = windex[au];
    weights[w] += delta;
  }
}


// Add -lr * (output - old_output) to each of the C weights indexed by
// `windex'. Only legal while a Train() call is pending (elimode == 2).
void FoxN::LimitOutputDerivative (cftype lr)
{
  myassert (elimode == 2,"out of order call to LimitOutputDerivative()");
  cftype delta = - lr * (output - old_output);
  for (int au=0; au<C; au++) {
    const int w = windex[au];
    weights[w] += delta;
  }
}


// If the current output exceeds `limit', add (limit - output)/C to each of
// the C weights indexed by `windex'; otherwise do nothing.
// Only legal while a Train() call is pending (elimode == 2).
void FoxN::UpperLimit (cftype limit)
{
  myassert (elimode == 2,"out of order call to UpperLimit()");
  if (!(output > limit)) return;

  cftype delta = (limit - output)/cftype(C);
  for (int au=0; au<C; au++) {
    const int w = windex[au];
    weights[w] += delta;
  }
}


// If the current output is below `limit', add (limit - output)/C to each of
// the C weights indexed by `windex'; otherwise do nothing.
// Only legal while a Train() call is pending (elimode == 2).
void FoxN::LowerLimit (cftype limit)
{
  myassert (elimode == 2,"out of order call to LowerLimit()");
  if (!(output < limit)) return;

  cftype delta = (limit - output)/cftype(C);
  for (int au=0; au<C; au++) {
    const int w = windex[au];
    weights[w] += delta;
  }
}


// Add (x - output)/C to each of the C weights indexed by `windex'.
// Only legal while a Train() call is pending (elimode == 2).
void FoxN::SetOutput (cftype x)
{
  myassert (elimode == 2,"out of order call to SetOutput()");
  cftype delta = (x - output)/cftype(C);
  for (int au=0; au<C; au++) {
    const int w = windex[au];
    weights[w] += delta;
  }
}


// We should be able to call Reset() multiple times without trouble.

// Clear all eligibility state: empty the inactive-weight buffer, reset the
// time / buffer-head counters, the last error and the trace object, and set
// every aux[] entry to -1. The weight values themselves are kept (in the
// fast method any pending corrections are applied to them first).
// NOTE(review): the assertion below only admits elimode == 1, so a call made
// while in mode 0 (e.g. through ResetAll()) trips it - confirm whether
// `elimode != 2' was intended.
void FoxN::Reset()
{
  myassert (elimode == 1,"out of order call to Reset()");

#ifdef FAST_METHOD
  // Get the correct values for all weights in the buffer. This will set all
  // buffer indexes to -1, effectively zeroing all the eligibilities.
  // Before doing this we must advance the current time to ensure the weight
  // updates are correct.
  time++;
  head += C;
  if (head >= bufsize1) head = 0;
  for (int i=0; i<bufsize1; i++) {
    if (buffer1[i] != -1) GetCorrectWeight (buffer1[i],1);
  }
#endif

#ifdef SLOW_METHOD
  // set all buffer indexes to -1 and zero all the eligibility vectors
  memset (buffer1,~0,sizeof(int) * bufsize1);
  memset (slow_eli, 0, wsize * C * ny * sizeof(cftype));
#endif

#ifndef NDEBUG
  // Check that all buffer indexes are -1
  for (int i=0; i<bufsize1; i++)
    myassert (buffer1[i] == -1,"buffer not zeroed");
#endif

  // Start values: the first Map() in eligibility mode increments `time' to 0
  // and advances `head' by C to 0.
  time = -1;
  head = -C;
  lasterror = 0.0;
  trace->Reset();

  // clear out aux array: set indexes to -1
  memset (aux,~0,sizeof(int) * wsize*C);
}


void FoxN::ResetAll()
{
  Reset();
  memset (weights, 0, wsize*C * sizeof(cftype));
}


// Return 1 if every one of the `nin' input values lies within the closed
// range [min,max] configured for that input, otherwise 0.
int FoxN::InRange (cftype *input)
{
  for (int k=0; k<nin; k++) {
    if (input[k] > iinfo[k].max || input[k] < iinfo[k].min) return 0;
  }
  return 1;
}


// Quantize the `nin' input values into `qin'. Input i is mapped to an
// integer cell number in 0..res-1: values at or above `max' give res-1,
// values at or below `min' give 0, and values in between are scaled by
// q = res/(max-min) and truncated.
void FoxN::Quantize (cftype *input)
{
  for (int i=0; i<nin; i++) {
    const int top = iinfo[i].res-1;     // highest valid cell number

    if (input[i] >= iinfo[i].max) qin[i] = top;
    else if (input[i] <= iinfo[i].min) qin[i] = 0;
    else {
#     ifdef USING_HPFLOAT
      cftype result = (input[i] - iinfo[i].min) * iinfo[i].q;
      int cell = (int) result.Double();
#     else
      int cell = int ( (input[i] - iinfo[i].min) * iinfo[i].q );
#     endif
      // Floating point rounding can make the scaled value of an input just
      // below `max' come out as exactly `res'; keep the cell number in range.
      if (cell > top) cell = top;
      qin[i] = cell;
    }
  }
}


int FoxN::Association (int a)
{
  int i,r;
  unsigned long hh,h;		// (hh,h) = 64 bit hash key
  int oo;			// AU receptive field ("overlay") offset
  int addr;			// returned address

  myassert (a>=0 && a<C,"bad argument `a' (%d) in Association()",a);

  // Generate the hash key for association unit 'a': build the hash key
  // based on the receptive field numbers for all inputs.
  h = 0;
  for (i=0; i<nin; i++) {	// for all inputs
    oo = OverlayOffset (i,a);
    r = (qin[i] + oo) / C;	// get receptive field coordinate
    h = h*iinfo[i].parts + r;	// build hash key
  }
  hh=a;				// <-- NOT h=0, see comments in header
  pdes_hash (hh,h);
  addr = (hh % wsize) + a*wsize;
  return addr;
}


// Fill `windex' with the weight index of each of the C association units.
void FoxN::GetAssociations()
{
  int au = 0;
  while (au < C) {
    windex[au] = Association (au);
    au++;
  }
}


// Bring weight `i' up to date and remove it from the inactive-weight buffer.
//
//   i    weight index.
//   old  if nonzero the weight is being retired: its value is corrected but
//        no new eligibility vector is computed. If zero the weight's
//        eligibility, decayed up to the current time, is computed as well.
//
// Returns a pointer to an eligibility vector of `ny' values: an all-zero
// vector if the weight was not in the buffer (aux[i] < 0), otherwise the
// `neweli' vector. Both vectors are function-local statics, so the contents
// are overwritten by the next call. When `old' is nonzero `neweli' is not
// recomputed and the returned contents are not meaningful.
const cftype * FoxN::GetCorrectWeight (int i, int old)
{
  static cftype neweli [MAX_NY];
  static cftype zeroeli [MAX_NY] = {0,0,0,0};

  // Is weight in buffer? If not it already has the correct value.
  if (aux[i] < 0) return zeroeli;

  // take weight out of buffer
  buffer1 [aux[i]] = -1;

  // find buffer "timestep" position
  int bpos = aux[i] / C;

  // find time `t' at which weight was last active
  int t,hd;
  hd = head / C;
  t = time + bpos - hd;
  // buffer positions at or past the head (past only, when not retiring)
  // were written on the previous trip around the circular buffer
  if (old ? (bpos >= hd) : (bpos > hd)) t -= history;

  myassert (t >= 0 && t <= time && t >= (time-history),
	    "bad time in GetCorrectWeight");

  cftype *eli = buffer2 + aux[i] * ny;	// eligibility vector for this weight

  // Find increment to weight (for all outputs), and change in eligibility.
  if ((time-t)==1) {
    // If weight was active last timestep (as most weights usually are) then
    // we can do it the easy way.
    cftype dot = 0.0;
    for (int k=0; k<ny; k++) dot += matC[k] * eli[k];
    weights[i] += lasterror * dot;
    // update eligibility (decay over one timestep)
    if (!old) MatrixMultiply2 (ny,neweli,matA,eli);
  }
  else {
    // Otherwise we must use the trace object. Remember that the trace object
    // thinks the current time is one timestep ago.
    cftype dot,*delta;
    delta = trace->GetDelta (t);
    dot = 0.0;
    for (int k=0; k<ny; k++) dot += delta[k] * eli[k];
    weights[i] += dot;
    // update eligibility (decay over the interval)
    if (!old) atoi->PreMulByAtoPlus (neweli,eli,time-t);
  }

  // mark the weight as no longer being in the buffer
  aux[i] = -1;
  return neweli;
}
