/*****************************************************************************
 *                                                                           *
 * FOX Controller Library, Version 1.3.3.                                    *
 *                                                                           *
 * Copyright (C) 1998,1999,2000 Russell Smith (rl.smith@auckland.ac.nz)      *
 *                                                                           *
 * The FOX Controller Library is free software; you can redistribute it      *
 * and/or modify it under the terms of the GNU Library General Public        *
 * License as published by the Free Software Foundation; either version      *
 * 2 of the License, or (at your option) any later version.                  *
 *                                                                           *
 * The FOX Controller Library is distributed in the hope that it will be     *
 * useful, but WITHOUT ANY WARRANTY; without even the implied warranty of    *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU         *
 * Library General Public License for more details.                          *
 *                                                                           *
 * You should have received a copy of the GNU Library General Public         *
 * License along with the Fox Controller Library; see the file COPYING.LIB.  *
 * If not, write to the Free Software Foundation, Inc., 59 Temple Place -    *
 * Suite 330, Boston, MA 02111-1307, USA.                                    *
 *                                                                           *
 *****************************************************************************/

// Test of the FOX-N algorithm on simple linear systems.

#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <math.h>

#include <vector>

#include "fox-n.h"


// Return a pseudo-random value in [0, 1].
// POSIX random() produces values in [0, 2^31 - 1] regardless of RAND_MAX,
// which only describes the range of rand().  Normalising by RAND_MAX is
// therefore wrong on platforms where RAND_MAX is smaller (e.g. 32767), so
// divide by the fixed POSIX upper bound instead.
cftype frandom()
{
  const long random_max = 2147483647L;	// 2^31 - 1, upper bound of random()
  return cftype(random())/cftype(random_max);
}


// Write the first n values of y to stdout on one line, each formatted as
// "%.8e " (so every value is followed by a space), then a newline.  Each call
// therefore emits one whitespace-separated row of output.
//
//   n : number of values to write; if n <= 0 only the newline is printed.
//   y : values to print; read only, never modified.
void PlotData (int n, const float *y)
{
  for (int i=0; i<n; i++) {
    printf ("%.8e ",y[i]);
  }
  printf ("\n");
}


// system parameters
// The plant is the forward-Euler discretisation (step kh) of
//   dy0/dt = y1
//   dy1/dt = -ka*y0 - kb*y1 + ka*x
// i.e. y[k+1] = A*y[k] + B*x[k], with the 2x2 matrix A stored row-major in
// Adata and the 2x1 vector B in Bdata.  main() simulates the plant with
// exactly this update and also passes Adata/Bdata/Cdata to the FoxN
// constructor.
const cftype ka=2;
const cftype kb=1;
const cftype kh=0.1;		// step size
cftype Adata[4] = {1,kh,-ka*kh,1-kb*kh};
cftype Bdata[2] = {0,kh*ka};
// Weights of the training error formed in main():
//   error = Cdata[0]*(yref - y0) - Cdata[1]*y1
cftype Cdata[2] = {1,0.1};

// simulation parameters
const int T=400;		// total number of time steps
// CMAC description of the single controller input, which main() feeds with
// the time value i*kh + kh/2 of each step.
// NOTE(review): the fields are presumably {resolution, min, max}, i.e. T
// cells over the simulated time span [0, T*kh]; confirm against the
// CmacInputInfo definition in fox-n.h.
CmacInputInfo cmac_input_info[1] = {{T,0,T*kh}};

// misc parameters
// NOTE(review): yscale is not referenced anywhere in the visible code.
const float yscale = 5.0;


int main (int argc, char *argv[])
{
  cftype yref[T];

  // initialize random numbers
  //srandom(1234567);
  srandom(time(NULL));

  // make random-ish yref trajectory to follow
  for (int i=0; i<T; i++) yref[i] = 0;
  for (int i=0; i<10; i++) {
    cftype xc = frandom()*cftype(T)*kh;
    cftype gain = 10*(frandom()-0.5);
    for (int j=0; j<T; j++)
      yref[j] += 1.0/(1.0+exp(-((cftype(j)*kh)-xc)*gain));
  }
  // starts at 0
  for (int i=1; i<T; i++) yref[i] -= yref[0];
  yref[0] = 0;

  // Plot reference trajectory and grid
  PlotData (T,yref);

  // setup fox
  const int num_inputs = 1;
  const int num_weights = 50000;
  const int history = 100; //@@@
  const int C = 10;
  const int ny = 2;
  FoxN fox (num_inputs,cmac_input_info,num_weights,C,ny,
	    Adata,Bdata,Cdata,history);
  fox.SetEligibilityMode (1);

  // setup storage
  cftype *this_y0 = new cftype [T];
  cftype *this_y1 = new cftype [T];
  cftype *save_y0 = new cftype [T];
  cftype *save_y1 = new cftype [T];
  cftype *this_x = new cftype [T];
  cftype *save_x = new cftype [T];
  for (int i=0; i<T; i++) this_y0[i] = this_y1[i] = 0;
  for (int i=0; i<T; i++) save_y0[i] = save_y1[i] = 0;
  for (int i=0; i<T; i++) save_x[i] = 0;

  for (int iteration=0; iteration<20000; iteration++) {
    cftype y[2];
    y[0] = -3;
    y[1] = 0;
    fox.Reset();

    for (int i=0; i<T; i++) {
      cftype input = cftype(i)*kh + kh/2.0;
      fox.Map (&input);
      // cftype x = yref[i] + fox.Output()
      cftype x = fox.Output();

      this_y0[i] = y[0];
      this_y1[i] = y[1];
      this_x[i] = x;

      cftype learning_rate = 0.005;
      // fox.LimitOutput (learning_rate * 0.003);
      fox.LimitOutputDerivative (learning_rate * 0.01);

      static cftype old_pos_error = 0;
      cftype pos_error = yref[i] - y[0];
      cftype error = Cdata[0]*(yref[i] - y[0]) - Cdata[1]*y[1];
      if (fabs(pos_error) > fabs(old_pos_error))
	fox.Train (error * learning_rate);
      else
	fox.Train (error * learning_rate * 0.1);
      old_pos_error = pos_error;


      cftype newy[2];
      newy[0] = Adata[0]*y[0] + Adata[1]*y[1] + Bdata[0]*x;
      newy[1] = Adata[2]*y[0] + Adata[3]*y[1] + Bdata[1]*x;
      y[0] = newy[0];
      y[1] = newy[1];
    }

    // plot the results
    if ((iteration % 10)==0) {
      fprintf (stderr,"%d\n",iteration);

      PlotData (T,this_y0);
      PlotData (T,this_y1);
      PlotData (T,this_x);

      cftype *tmp = this_y0;
      this_y0 = save_y0;
      save_y0 = tmp;

      tmp = this_y1;
      this_y1 = save_y1;
      save_y1 = tmp;

      tmp = this_x;
      this_x = save_x;
      save_x = tmp;
    }
  }

  return 0;
}
