#include "homodyne.h"
#include "data.h"
#include "controller.h"

#include <odindata/fitting.h>

// Partial-Fourier reconstruction along 'interpolDim' (array index 'interpolIndex').
// 'orthoIndex' is the other non-read array index; array index 2 is the read direction.
//
// 'startindex' is the smallest line index reported by 'measlines' for the current
// coordinate. Lines below it are treated as not acquired. Relative to
// center=0.5*(nlines-1), the method is chosen as follows:
//   - startindex <  center            : homodyne reconstruction
//                                       (startindex==0, i.e. nothing missing: plain FFT only)
//   - center <= startindex <= center+1 : conjugate-mirror reconstruction
//                                       (only implemented for interpolDim==line)
//   - otherwise                       : error, returns false
// The 3D data of 'rd' is modified in place. In the homodyne case with missing lines,
// rd's array is re-referenced to a new array built from the real part.
// On success, returns the result of execute_next_step(rd,controller).
template<recoDim interpolDim, int interpolIndex, int orthoIndex>
bool RecoHomodyne<interpolDim,interpolIndex,orthoIndex>::process(RecoData& rd, RecoController& controller) {
  Log<Reco> odinlog(c_label(),"process");

  Range all=Range::all();

  ComplexData<3>& data=rd.data(Rank<3>());
  TinyVector<int,3> shape=data.shape();
  int nlines=shape(interpolIndex);

  if(rd.interpolated) measlines.update(*(rd.interpolated)); // take previously interpolated coords into account

  ivector indexvec=measlines.get_indices(rd.coord());

  // [startindex,endindex] is the index range symmetric around 'center'.
  int startindex=indexvec.minvalue();
  float center=0.5*float(nlines-1);
  int endindex=nlines-1-startindex;
  ODINLOG(odinlog,normalDebug) << "startindex/endindex/center=" << startindex << "/" << endindex << "/" << center << STD_endl;

  if(float(startindex)<center) { // True homodyne reco
    ODINLOG(odinlog,normalDebug) << "Homodyne reco"<< STD_endl;

    // Image reconstructed from the symmetric k-space range only. Its phase is
    // removed from the full image further below. Stays all-zero if startindex==0.
    ComplexData<3> symmetric(shape); symmetric=STD_complex(0.0);

    if(startindex>0) { // Don't do this for fully-sampled k-space

      // symmetric part (only interpolIndex 0 or 1 is handled here)
      Range symrange(startindex,endindex);
      if(interpolIndex==0) symmetric(symrange,all,all)=data(symrange,all,all);
      if(interpolIndex==1) symmetric(all,symrange,all)=data(all,symrange,all);
      symmetric.fft();

      // pre weighting along interpolIndex: 0 for lines < startindex, a linear
      // ramp 0.0..1.0 across [startindex,endindex], and 1 for lines > endindex.
      // NOTE(review): homodyne is often formulated with a 0..2 ramp and weight 2
      // on the asymmetric part. With 0..1 / 1 the real-part image below appears
      // to come out at half scale. Confirm whether absolute intensity matters downstream.
      fvector transition(endindex-startindex+1);
      transition.fill_linear(0.0,1.0);
      ComplexData<1> weight(shape(interpolIndex));
      weight=STD_complex(1.0);
      // All three components of 'index' are assigned in the loops below,
      // provided {interpolIndex,orthoIndex} is {0,1}.
      TinyVector<int,3> index;
      for(int i=0; i<startindex; i++) weight(i)=STD_complex(0.0); // make sure we blank all non-acquired lines
      for(int i=0; i<int(transition.size()); i++) weight(startindex+i)=transition[i];
      // Multiply every sample by the weight of its position along interpolIndex.
      for(int iortho=0; iortho<shape(orthoIndex); iortho++) {
        index(orthoIndex)=iortho;
        for(int ipol=0; ipol<shape(interpolIndex); ipol++) {
          index(interpolIndex)=ipol;
          for(int iread=0; iread<shape(2); iread++) {
            index(2)=iread;
            data(index)*=weight(ipol);
          }
        }
      }
    }


    // do the FFT
    data.fft();


    if(startindex>0) { // Don't do this for fully-sampled k-space

      // phase correction: remove the phase of the symmetric-part image
      data*=expc(float2imag(-phase(symmetric)));

      // take real part: rd's array now references a new complex array built
      // from the real part of 'data'
      ComplexData<3> realpart(shape);
      realpart=float2real(creal(data));
      data.reference(realpart);
    }


  } else if(fabs(float(startindex)-center)<=1.0) {  // Just mirror k-space
    // Only reached for center <= startindex <= center+1, because smaller
    // values are handled by the homodyne branch above.
    ODINLOG(odinlog,significantDebug) << "Conjugate-mirror reco with startindex=" << startindex << STD_endl;

    // The code below hard-codes array index 1 as the line index and
    // index 0 as the 3D phase index.
    if(interpolDim!=line) {
      ODINLOG(odinlog,errorLog) << "Not implemented: Conjugate-mirror reco in dimension " << recoDimLabel[interpolDim] << STD_endl;
      return false;
    }


    int nread=shape(2);


    // Apply phase correction using the central line:
    // fit a straight line to the read-direction phase (via phasemap()) of line
    // 'startindex' at 3D index shape(0)/2. secureInv(|magnitude|^2) is passed
    // as the fit's second argument (presumably per-point errors; check
    // LinearFunction::fit). Then remove that linear phase from all data.

    data.partial_fft(TinyVector<bool,3>(false,false,true)); // FFT in read

    ComplexData<1> centline(data(shape(0)/2,startindex,all));

    Data<float,1> centphase(centline.phasemap());
    Data<float,1> centmagn(cabs(centline));
    Data<float,1> centerr(nread);
    for(int iread=0; iread<nread; iread++) centerr(iread)=secureInv(pow(centmagn(iread),2));

    LinearFunction linf;
    linf.fit(centphase,centerr);

    for(int iread=0; iread<nread; iread++) {
      data(all,all,iread)*=expc(STD_complex(0.0, -(iread*linf.m.val+linf.c.val)));
    }

    data.partial_fft(TinyVector<bool,3>(false,false,true),false); // inv FFT in read


    // For each 3D index, fill line startindex-iline with the complex conjugate
    // of line startindex+iline, reversed along read. Destinations below line 0
    // are skipped.
    // NOTE(review): iphase3d is not mirrored. For shape(0)>1 this equals the
    // Hermitian counterpart only if the phase3d dimension is already in image
    // space at this point. Confirm.
    ComplexData<1> oneline(nread);
    int nmeas=nlines-startindex;
    for(int iphase3d=0; iphase3d<shape(0); iphase3d++) {
      for(int iline=1; iline<nmeas; iline++) {

        // mirror remaining lines
        int isrc=startindex+iline;
        int idst=startindex-iline;
        if(idst>=0) {
          oneline=conjc(data(iphase3d,isrc,all).reverse(0));
          if(!(nread%2)) oneline.shift(0,1); // for even nread, the center frequency is at nread/2
          data(iphase3d,idst,all)=oneline;
        }
      }
    }

    // NOTE(review): this clears line index 0 (all read points). The cyclic read
    // shift above, however, wraps a sample into read index 0 of each mirrored line
    // (assuming shift(0,1) moves samples to higher indices). Also, the condition
    // tests nread rather than nlines parity. Confirm which index was intended.
    if(!(nread%2)) data(all,0,all)=STD_complex(0.0); // zero-fill garbage from cyclic shift

    // do the FFT
    data.fft();

/*
    Data<float,3>(creal(data)).autowrite("creal_"+rd.coord().print(RecoIndex::filename)+".jdx");
    Data<float,3>(cimag(data)).autowrite("cimag_"+rd.coord().print(RecoIndex::filename)+".jdx");
    Data<float,3>(cabs(data)).autowrite("cabs_"+rd.coord().print(RecoIndex::filename)+".jdx");
    Data<float,3>(phase(data)).autowrite("phase_"+rd.coord().print(RecoIndex::filename)+".jdx");
*/


  } else {
    // First measured line lies beyond center+1: neither method applies.
    ODINLOG(odinlog,errorLog) << "startindex=" << startindex << " larger than center=" << center << STD_endl;
    return false;
  }

  return execute_next_step(rd,controller);
}

///////////////////////////////////////////////////////


// Query hook of this step. In the preparation pass (RecoQueryContext::prep),
// 'measlines' is initialized with the context's coordinate and controller;
// if that fails, false is returned immediately. Otherwise the query is
// forwarded to RecoStep::query().
template<recoDim interpolDim, int interpolIndex, int orthoIndex>
bool RecoHomodyne<interpolDim,interpolIndex,orthoIndex>::query(RecoQueryContext& context) {
  Log<Reco> odinlog(c_label(),"query");
  const bool prepare_pass = (context.mode==RecoQueryContext::prep);
  if(prepare_pass && !measlines.init(context.coord, context.controller)) {
    return false;
  }
  return RecoStep::query(context);
}
