/*****************************************************************************
 *                                                                           *
 * FOX Controller Library, Version 1.3.3.                                    *
 *                                                                           *
 * Copyright (C) 1998,1999,2000 Russell Smith (rl.smith@auckland.ac.nz)      *
 *                                                                           *
 * The FOX Controller Library is free software; you can redistribute it      *
 * and/or modify it under the terms of the GNU Library General Public        *
 * License as published by the Free Software Foundation; either version      *
 * 2 of the License, or (at your option) any later version.                  *
 *                                                                           *
 * The FOX Controller Library is distributed in the hope that it will be     *
 * useful, but WITHOUT ANY WARRANTY; without even the implied warranty of    *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU         *
 * Library General Public License for more details.                          *
 *                                                                           *
 * You should have received a copy of the GNU Library General Public         *
 * License along with the Fox Controller Library; see the file COPYING.LIB.  *
 * If not, write to the Free Software Foundation, Inc., 59 Temple Place -    *
 * Suite 330, Boston, MA 02111-1307, USA.                                    *
 *                                                                           *
 *****************************************************************************/

#include <string.h>
#include <math.h>

#include "atoi.h"

//***************************************************************************
// Matrix inversion

// Square-matrix helper used only for inversion.  It wraps a caller-supplied
// array of n*n elements stored row by row; the array is referenced, never
// copied or freed by this class.
class Matrix {
private:
  int n;		// dimension (matrix is n*n); 0 if the constructor failed
  int size;		// element count n*n - only meaningful when n > 0
  cftype *data;		// the elements, row by row (not owned)

  // In-place LU decomposition with row pivoting; records the row
  // interchanges in perm[0..n-1] and returns their parity (0 if n==0).
  int LUDecompose (int *perm);
  // Solve for x in place in b, using the LU factors and perm[] from above.
  void Solve (cftype *b, int *perm);

public:
  Matrix (int _n, cftype *_data);
  // Overwrite this matrix with the inverse of `a' (a's data is overwritten
  // with its LU factors in the process).
  void SetInverse (Matrix &a);
};

//...........................................................................
// Debugging macros to test pointer bounds

// #define CHECKIT

#ifdef CHECKIT

// Bounds-checking macros, enabled by defining CHECKIT above.  Each one
// expands to a bare `if' statement that already ends in `;', which is why
// they are written without a trailing semicolon at the call sites.  Because
// there is no do/while(0) wrapper, do not use them as the body of an if/else
// (dangling-else).  On failure they call ZFatalError (defined elsewhere,
// presumably fatal - confirm).

// CHECK(p): p must point into this matrix's elements [data, data+size).
// Uses the `data' and `size' members, so only usable inside Matrix methods.
#define CHECK(pointer) if ((pointer)<data || (pointer)>=(data+size)) \
  ZFatalError ("Pointer check failed at " __FILE__ ":%d",__LINE__);
// CHECK2(p,obj,limit): p must point into the array obj[0..limit).
#define CHECK2(pointer,obj,limit) \
  if ((pointer)<(obj) || (pointer)>=((obj)+limit)) \
    ZFatalError ("Pointer check failed at " __FILE__ ":%d",__LINE__);
// CHECK3(i): index must satisfy 0 <= i < n (uses the `n' member).
#define CHECK3(num) \
  if ((num)<0 || (num)>=n) \
    ZFatalError ("Range check failed at " __FILE__ ":%d",__LINE__);
// CHECK_MESSAGE: printed by the Matrix constructor when checking is on.
// NOTE(review): uses printf, but this file does not include <stdio.h>;
// presumably it arrives via atoi.h - confirm before defining CHECKIT.
#define CHECK_MESSAGE printf ("Matrix checking turned on\n");

#else

// Checking disabled: all macros expand to nothing.
#define CHECK(pointer) /**/
#define CHECK2(pointer,obj,limit) /**/
#define CHECK3(num) /**/
#define CHECK_MESSAGE /**/

#endif

//...........................................................................

// This is based on the ludcmp routine from Numerical Recipes in C (NRC), section 2.3.

// In-place LU decomposition of this matrix (Crout's method with partial
// pivoting and implicit row scaling).  On return `data' holds the factors of
// the row-permuted matrix: L strictly below the diagonal (L's unit diagonal
// is implied, not stored) and U on and above the diagonal.
//
// perm:    output array of at least n entries; perm[j] is the row that was
//          interchanged with row j at step j.  Solve() replays these swaps
//          in the same order.
// returns: +1 or -1 (parity of the number of row interchanges), or 0 if
//          the matrix is empty (n==0).
// A row of all zeros is reported through checkthat2(); NOTE(review): whether
// that aborts or returns early is defined elsewhere - confirm.  An exactly
// zero pivot met later is not reported; it is replaced by 1e-20.

int Matrix::LUDecompose (int perm[])
{
  int i,imax,j,k,swaps;		// swaps = parity (+1/-1) of row interchanges
  cftype big,dum,sum,temp;
  cftype *Aptr,*rptr,*cptr;	// row & col pointers for Crout's method

  if (n==0) return 0;

  cftype vv [n];		// vv = the implicit scaling of each row
  swaps = 1;
  imax = 0;			// to prevent "possibly unused" warning

  // loop over rows to get the implicit scaling information:
  // vv[i] = 1 / (largest magnitude in row i)
  Aptr = data;
  for (i=0; i<n; i++) {
    big = 0.0;
    for (j=n; j; j--) {
      CHECK(Aptr)
      if ((temp = fabs(*Aptr++)) > big) big = temp;
    }
    checkthat2 (big != 0.0,"Singular matrix in Matrix::LUDecompose");
    CHECK3(i)
    vv[i] = 1.0/big;
  }

  for (j=0; j<n; j++) {		// loop over columns
    // rows above the diagonal:
    //   U[i][j] = a[i][j] - sum over k<i of L[i][k]*U[k][j]
    Aptr = data+j;
    for (i=0; i<j; i++) {
      CHECK(Aptr)
      sum = *Aptr;
      rptr = data + n*i;	// refs elements in a row
      cptr = data + j;		// refs elements in a column
      for (k=i; k; k--) {
        CHECK(rptr)
        CHECK(cptr)
        sum -= (*rptr++) * (*cptr);
        cptr += n;		// jump to next row
      }
      CHECK(Aptr)
      *Aptr = sum;
      Aptr += n;		// jump to next row
    }

    // rows on and below the diagonal (Aptr now points at a[j][j]): subtract
    // the sum over k<j, giving U[j][j] and the not-yet-divided L[i][j], and
    // pick the pivot row with the largest scaled magnitude.
    big = 0.0;
    for (i=j; i<n; i++) {
      CHECK(Aptr)
      sum = *Aptr;
      rptr = data + n*i;	// refs elements in a row
      cptr = data + j;		// refs elements in a column
      for (k=j; k; k--) {
        CHECK(rptr)
        CHECK(cptr)
        sum -= (*rptr++) * (*cptr);
        cptr += n;              // jump to next row
      }
      CHECK(Aptr)
      *Aptr = sum;
      // if the pivot figure of merit better than the current best...
      // (`>=' with big starting at 0 means imax is assigned on the first
      // pass of this loop, barring NaNs, so it always lies in [j,n))
      CHECK3(i)
      if ((dum=vv[i]*fabs(sum)) >= big) {
        big=dum;
        imax=i;
      }
      Aptr += n;		// jump to next row
    }

    if (j != imax) {		// do we need to swap rows?
      Aptr = data + n*imax;
      rptr = data + n*j;
      for (k=n; k; k--) {	// row swap
        CHECK(Aptr)
        CHECK(rptr)
        dum = *Aptr;
	*Aptr++ = *rptr;
	*rptr++ = dum;
      }
      swaps=-swaps;		// record row swap parity
      CHECK3(imax)
      CHECK3(j)
      // only later rows' scale factors are read again, so vv[j] need not be
      // preserved - a one-way copy suffices
      vv[imax] = vv[j];		// also interchange scale factor
    }
    CHECK3(j)
    perm[j] = imax;		// remember which row was swapped into row j

    Aptr = data + (n+1)*j;	// Aptr = pivot element
    CHECK(Aptr)
    if ((*Aptr) == 0.0) {
      // exactly-zero pivot: substitute a tiny value rather than divide by 0
      *Aptr = 1.0e-20;		// Cheating!
    }

    if (j != (n-1)) {		// divide by the pivot element
      // scale the sub-diagonal part of column j to obtain L[i][j]
      dum = 1.0/(*Aptr);
      Aptr += n;
      for (i=j+1; i<n; i++) {
        CHECK(Aptr)
        (*Aptr) *= dum;
        Aptr += n;
      }
    }
  }

  return swaps;
}


// This is based on the lubksb routine from Numerical Recipes in C (NRC), section 2.3.

// Solve A*x = b in place.  This matrix must hold the LU factors produced by
// LUDecompose(), and perm[] the row-interchange record it filled in.  b has
// n entries and is overwritten with x.  Leading zeros of b are skipped during
// forward substitution, which saves work when b is a unit vector (as in
// SetInverse()).  Does nothing when n==0.

void Matrix::Solve (cftype *b, int perm[])
{
  if (n==0) return;

  // Forward substitution through L (unit diagonal), replaying the row
  // interchanges as we go.  `first' is the index of the first nonzero
  // element met so far, or -1 while everything has been zero.
  int first = -1;
  for (int row=0; row<n; row++) {
    const int swapped = perm[row];
    cftype acc = b[swapped];
    b[swapped] = b[row];
    if (first < 0) {
      if (acc != 0) first = row;	// nonzero found, start summing from here
    }
    else {
      const cftype *Lrow = data + row*n;
      for (int col=first; col<row; col++) {
        CHECK(Lrow+col)
        CHECK2(b+col,b,n)
        acc -= Lrow[col] * b[col];
      }
    }
    b[row] = acc;
  }

  // Back substitution through U, bottom row first.  For the last row the
  // inner loop is empty and this reduces to b[n-1] /= U[n-1][n-1].
  for (int row=n-1; row>=0; row--) {
    const cftype *Urow = data + row*n;
    cftype acc = b[row];
    for (int col=row+1; col<n; col++) {
      CHECK(Urow+col)
      CHECK2(b+col,b,n)
      acc -= Urow[col] * b[col];
    }
    CHECK(Urow+row)
    b[row] = acc/Urow[row];		// store component of solution 'x'
  }
}


// Wrap an existing array of _n*_n elements stored row by row.  The array is
// referenced, not copied, and is never freed by Matrix.  A dimension below 1
// marks the matrix as failed (n = 0), which makes the other methods no-ops.

Matrix::Matrix (int _n, cftype *_data)
{
  n = (_n < 1) ? 0 : _n;
  size = n*n;		// computed after clamping so it always agrees with n
  data = _data;
  CHECK_MESSAGE
}


// Set this matrix to the inverse of `a'.
// NOTE: `a' is overwritten with its LU factors (LUDecompose works in place),
// so pass a scratch copy if the original values are still needed.
// Does nothing if either matrix failed construction (n==0), if their sizes
// differ, or if `a' is this matrix (in-place inversion is not supported).

void Matrix::SetInverse (Matrix &a)
{
  if (n==0 || &a == this || a.n != n) return;

  int perm[n];
  // LUDecompose returns the row-swap parity (+1/-1) on success and 0 when it
  // did not decompose anything.  NOTE(review): if checkthat2() returns 0 on a
  // singular matrix instead of aborting, perm[] is incomplete at this point,
  // so stop rather than call Solve() with it (this matrix is left unchanged).
  if (a.LUDecompose (perm) == 0) return;

  cftype b [n];
  for (int j=0; j<n; j++) {	// find inverse by columns
    // make b into j'th column of identity, then solve a*x = b for it
    for (int i=0; i<n; i++) b[i]=0.0;
    b[j] = 1.0;
    a.Solve (b,perm);

    // copy the solution into column j of this matrix
    cftype *Aptr = data+j;
    for (int i=0; i<n; i++) {
      CHECK(Aptr)
      CHECK2(b+i,b,n)
      *Aptr = b[i];
      Aptr += n;		// jump to next row
    }
  }
}

//...........................................................................
// Test

/*

#include <stdio.h>
#include <stdlib.h>
#include <time.h>


cftype frandom()
{
  return cftype(random()) / cftype(RAND_MAX);
}


int main()
{
  cftype Adata[16],Ainvdata[16];
  Matrix A(4,Adata),Ainv(4,Ainvdata);

  srandom (time(NULL));
  for (int i=0; i<16; i++) Adata[i] = frandom();

  FILE *f = fopen ("z1","w");
  for (int i=0; i<4; i++) {
    for (int j=0; j<4; j++) fprintf (f,"%.15f ",Adata[i*4+j]);
    fprintf (f,"\n");
  }
  fclose (f);

  Ainv.SetInverse (A);

  f = fopen ("z2","w");
  for (int i=0; i<4; i++) {
    for (int j=0; j<4; j++) fprintf (f,"%.15f ",Ainvdata[i*4+j]);
    fprintf (f,"\n");
  }
  fclose (f);
}

*/

//***************************************************************************
// GetAtoi


GetAtoi::GetAtoi (int _ny, cftype *A, int _n)
{
  int i,j,k,l;

  valid = 0;
  Atoi = NULL;
  invAtoi = NULL;

  ny = _ny;
  ny2 = ny*ny;
  n = _n;
  checkthat (ny <= MAX_NY,"ny > %d in GetAtoi ctor",MAX_NY);
  checkthat (n >= 1,"bad n in Atoi ctor");
  checkmemptr (Atoi = new cftype [ny2 * n]);
  checkmemptr (invAtoi = new cftype [ny2 * n]);

  // get inverse of A
  cftype Adata[ny2],invAdata [ny2];
  memcpy (Adata,A,ny2*sizeof(cftype));
  Matrix matA (ny,Adata);
  Matrix invA (ny,invAdata);
  invA.SetInverse (matA);

  // get X1 = X2 = identity
  cftype X1 [MAX_NY * MAX_NY];
  cftype X2 [MAX_NY * MAX_NY];
  for (i=0; i<ny2; i++) X1[i] = X2[i] = 0.0;
  for (i=0; i<ny; i++) X1[(ny+1)*i] = X2[(ny+1)*i] = 1.0;

  // make the first entries in `Atoi' and `invAtoi'  - the identity matrix
  for (i=0; i<ny2; i++) Atoi[i] = invAtoi[i] = 0.0;
  for (i=0; i<ny; i++) Atoi[(ny+1)*i] = invAtoi[(ny+1)*i] = 1.0;

  cftype sum;
  cftype *oldp1,*p1 = Atoi + ny2;
  cftype *oldp2,*p2 = invAtoi + ny2;
  for (i=1; i<n; i++) {
    oldp1 = p1;
    oldp2 = p2;

    // postmultiply X1 by A to get the next entry in `invAtoi'
    for (j=0; j<ny; j++) {
      for (k=0; k<ny; k++) {
	sum = 0;
	for (l=0; l<ny; l++) sum += X1[j*ny+l] * A[l*ny+k];
	*(p1++) = sum;
      }
    }

    // postmultiply X2 by inverse of A to get the next entry in `invAtoi'
    for (j=0; j<ny; j++) {
      for (k=0; k<ny; k++) {
	sum = 0;
	for (l=0; l<ny; l++) sum += X2[j*ny+l] * invAdata[l*ny+k];
	*(p2++) = sum;
      }
    }

    // suck the results back into X1 and X2
    for (j=0; j<ny2; j++) X1[j] = *(oldp1++);
    for (j=0; j<ny2; j++) X2[j] = *(oldp2++);
  }

  valid = 1;

  // print the results
  //  for (i=0; i<n; i++) {
  //    for (j=0; j<ny; j++) {
  //      for (k=0; k<ny; k++) {
  //        printf ("%12.4f ",Atoi[i*ny2 + j*ny + k]);
  //      }
  //      printf ("\n");
  //    }
  //    printf ("\n");
  //  }
  //  exit(1);
}


// Release the two power tables.  The constructor sets both pointers to NULL
// before allocating, and delete[] on NULL does nothing, so this is safe even
// if construction stopped before the tables were allocated.
GetAtoi::~GetAtoi()
{
  delete[] invAtoi;
  delete[] Atoi;
}
