/**************************************************************************
 **       Title: demo program for linear solver test
 **        Date: 8.1.2009
 **   Copyright: Bernard Haasdonk
 **************************************************************************/

#include"config.h"

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

#include<dune/common/mpihelper.hh> // include mpi helper class 
#include<dune/grid/io/file/dgfparser.hh>
#include <dune/fem/function/adaptivefunction/adaptivefunction.hh>
#include <dune/fem/space/dgspace.hh>
#include <dune/grid/common/gridpart.hh>
#include<dune/grid/sgrid.hh> // load sgrid definition
#include<dune/fem/solver/oemsolver/oemsolver.hh> 

// Compile-time configuration of the demo, collected here to keep main()
// readable: dimensions and polynomial order, then the typedef chain
// grid -> function space -> grid part -> DG space -> discrete function.
const int dim = 2;
const int dimworld = 2;
// polynomial degree 0: the discontinuous Galerkin space below is piecewise constant
const int polOrder = 0;
// scalar-valued functions
const int dimrange = 1;

// structured grid; cell counts and extent are chosen in main()
typedef Dune::SGrid<dim,dimworld> GridType;

typedef double DomainFieldType;
typedef double RangeFieldType;
typedef Dune::FunctionSpace<DomainFieldType, RangeFieldType, 
                            dimworld,dimrange> FunctionSpaceType;

// the space is built on the leaf level of the grid
typedef Dune::LeafGridPart<GridType> GridPartType;
typedef Dune::DiscontinuousGalerkinSpace<FunctionSpaceType, GridPartType, 
                                         polOrder>
              DiscreteFunctionSpaceType;

typedef Dune::AdaptiveDiscreteFunction<DiscreteFunctionSpaceType> 
        DiscreteFunctionType;

// iterator over the degrees of freedom of a discrete function
typedef DiscreteFunctionType::DofIteratorType DofIteratorType;

// Forward declaration: InverseOperatorType below names the operator type
// before MyOperatorType is defined.
class MyOperatorType;

// Select the OEM solver used in main(). Exactly one of the following
// alternatives (CG, BiCGStab, BiCGSQ, GMRES) should be active; swap the
// comments to switch solvers.
//typedef Dune::OEMCGOp<DiscreteFunctionType,MyOperatorType>  
//        InverseOperatorType;
//typedef Dune::OEMBICGSTABOp<DiscreteFunctionType,MyOperatorType>  
//        InverseOperatorType;
typedef Dune::OEMBICGSQOp<DiscreteFunctionType,MyOperatorType>  
        InverseOperatorType;
//typedef Dune::OEMGMRESOp<DiscreteFunctionType,MyOperatorType>  
//        InverseOperatorType;

/** \brief Dense n x n matrix operator exposing the multOEM() interface
 *         expected by the OEM inverse operator (InverseOperatorType).
 *
 *  This class does not derive from Dune::Operator. Only the matrix-vector
 *  product multOEM() and the systemMatrix() accessor are provided.
 *
 *  Entries are stored row-major in one contiguous std::vector. rows_ holds
 *  one pointer per row into that storage, so matrix() can hand out the
 *  double** view that callers index as A[row][col]. All entries start at
 *  0.0.
 *
 *  Copies are deep: a copy owns its own storage, and its row pointers are
 *  re-bound to that storage.
 */
class MyOperatorType
{
public:
  /** \brief Create an n x n zero matrix.
   *  \param n  number of rows/columns; a non-positive n yields an empty matrix.
   */
  explicit MyOperatorType(int n)
    : size_(n > 0 ? n : 0),
      storage_(static_cast<std::size_t>(size_) * static_cast<std::size_t>(size_), 0.0),
      rows_(static_cast<std::size_t>(size_))
  {
    bindRows();
  }

  /** \brief Deep copy: duplicate the entries and point the rows at the new storage. */
  MyOperatorType(const MyOperatorType& other)
    : size_(other.size_),
      storage_(other.storage_),
      rows_(static_cast<std::size_t>(other.size_))
  {
    bindRows();
  }

  /** \brief Deep-copy assignment (copy-and-swap, strong exception guarantee). */
  MyOperatorType& operator=(const MyOperatorType& other)
  {
    MyOperatorType tmp(other);
    // vector::swap exchanges the buffers without moving the elements, so the
    // pointers in tmp.rows_ remain valid and end up pointing into storage_
    std::swap(size_, tmp.size_);
    storage_.swap(tmp.storage_);
    rows_.swap(tmp.rows_);
    return *this;
  }

  /** \brief Matrix-vector product dest = A * arg.
   *
   *  \param arg   points to n input values (read only)
   *  \param dest  points to n output values, all of which are overwritten
   *
   *  NOTE(review): the pointer-reference parameter types are kept exactly as
   *  they were. The OEM solver presumably calls this method with this
   *  signature; confirm against oemsolver.hh before changing it.
   */
  void multOEM(const double*& arg, 
               double*& dest) const
  {
    std::cout << "entered oemMult \n";

    for (int j = 0; j < size_; ++j)
    {
      const double* row = rows_[j];
      double val = 0.0;
      for (int i = 0; i < size_; ++i)
        val += row[i] * arg[i];
      dest[j] = val;
    }

    std::cout << "finished oemMult \n";
  }

  /** \brief The operator acts as its own system matrix (returns *this). */
  MyOperatorType& systemMatrix()
  {
    return *this;
  }

  /** \brief Writable row-pointer view of the entries: matrix()[row][col].
   *  Returns a null pointer for an empty (0 x 0) matrix. The view stays
   *  valid for the lifetime of this object.
   */
  double** matrix()
  {
    if (rows_.empty())
      return static_cast<double**>(0);
    return &rows_[0];
  }

private:
  /** Point rows_[j] at the first entry of row j inside storage_. */
  void bindRows()
  {
    for (int j = 0; j < size_; ++j)
      rows_[j] = &storage_[static_cast<std::size_t>(j) * static_cast<std::size_t>(size_)];
  }

  int size_;                     // number of rows (= number of columns), >= 0
  std::vector<double> storage_;  // size_*size_ entries, row-major
  std::vector<double*> rows_;    // rows_[j] == &storage_[j*size_]
};

// Demo driver: solve A x = b with b = (1, 1) for four hand-picked 2x2
// matrices (indefinite / positive definite, symmetric / non-symmetric),
// using the OEM inverse operator selected by InverseOperatorType, and print
// each solution vector.
int main(int argc, char **argv)
{ 
  // initialize MPI, finalize is done automatically on exit
  Dune::MPIHelper::instance(argc,argv);

  // structured grid with N = (2,1) cells between the corners L and H
  Dune::FieldVector<int,dim> N(2);                     /*@\label{gs:par0}@*/
  N[1] = 1;
  Dune::FieldVector<GridType::ctype,dim> L(-1.0);
  Dune::FieldVector<GridType::ctype,dim> H(1.0);       /*@\label{gs:par1}@*/
  GridType grid(N,L,H);                                /*@\label{gs:grid}@*/

  // piecewise-constant (polOrder 0) DG space on the leaf grid part
  GridPartType gridpart(grid);
  DiscreteFunctionSpaceType dfspace(gridpart);

  const int n = dfspace.size();
  std::cout << "size of df-space: " << n << "\n";

  // the test systems below are 2x2, so the space must have exactly two DOFs
  // (fewer would index the matrix out of bounds, more would leave it singular)
  if (n != 2)
  {
    std::cerr << "error: expected a df-space of size 2, got " << n << "\n";
    return 1;
  }

  // right-hand side b = (1, ..., 1)
  DiscreteFunctionType b("b",dfspace);
  for (DofIteratorType it=b.dbegin(); it!=b.dend(); ++it)
    *it = 1.0;

  // solution vector; it is cleared before every solve (initial guess)
  DiscreteFunctionType x("x",dfspace);

  // dense operator; its entries are overwritten for every test system
  MyOperatorType myOp(n);
  double** A = myOp.matrix();

  // solver parameters, identical for all systems
  const double redEps = 0.0;
  const double absLimit = 1e-15;
  const int maxIter = 20000;
  const bool verbose = true;

  // test systems: description printed before the solve, and matrix entries
  const int numSystems = 4;
  const char* const description[numSystems] = {
    "----------------\nindefinite, non-symmetric: \n",
    "----------------\nindefinite, symmetric: \n",
    "----------------\npositive definite, non-symmetric: \n",
    "----------------\npositive definite, symmetric: \n"
  };
  const double entries[numSystems][2][2] = {
    { { 2.0, 6.0 }, { 4.0, 2.0 } },
    { { 2.0, 4.0 }, { 4.0, 2.0 } },
    { { 2.0, 2.0 }, { 0.0, 4.0 } },
    { { 2.0, 1.0 }, { 1.0, 2.0 } }
  };

  for (int sys=0; sys<numSystems; ++sys)
  {
    x.clear();

    // try to use dune-fem solvers
    InverseOperatorType
        solver(myOp, redEps, absLimit, maxIter, verbose);

    std::cout << description[sys];
    for (int row=0; row<2; ++row)
      for (int col=0; col<2; ++col)
        A[row][col] = entries[sys][row][col];

    solver(b,x);

    // output solution vector:
    std::cout << "x = \n";
    for (DofIteratorType it=x.dbegin(); it!=x.dend(); ++it)
      std::cout << *it << " ";
    std::cout << "\n";
  } // end sys loop

  return 0;
}