""" Code to average FITS frames

Context : SRP
Module  : SRPAdvAverage.py
Version : 1.2.3
Author  : Stefano Covino
Date    : 12/10/2011
E-mail  : stefano.covino@brera.inaf.it
URL     : http://www.merate.mi.astro.it/~covino
Purpose : Manage the average of frame FITS files.

Usage   : SRPAdvAverage [-v] [-h] -i arg1 -o arg2 [-s arg3] [-x arg4]
            -i Input FITS file list
            -s Sigma-clipping level
            -x Input FITS exposure map file list
            -o Output FITS file

            The exposure maps, if available, allow compensating for less-exposed areas.

History : (13/11/2008) First version.
        : (16/11/2008) Management of different exposure times.
        : (24/04/2009) Management of lacking keywords.
        : (11/09/2009) Minor correction.
        : (20/08/2010) New sigma clipping average and exposure maps.
        : (14/10/2010) Better import style.
        : (07/08/2011) Better cosmetics.
        : (12/10/2011) Bug correction in Fits header reading.
"""



import os, os.path, string, types
from optparse import OptionParser
import SRP.SRPConstants as SRPConstants
import SRP.SRPFiles as SRPFiles
import SRP.SRPUtil as SRPUtil
import SRP.SRPAstro as SRPAstro
import numpy
import SRP.stats as stats
import pyfits
from SRP.SRPStatistics.AverIterSigmaClipp import AverIterSigmaClipp


parser = OptionParser(usage="usage: %prog [-v] [-h] -i arg1 -o arg2 [-s arg3] [-x arg4]", version="%prog 1.2.3")
parser.add_option("-i", "--inputlist", action="store", nargs=1, type="string", dest="fitsfilelist", help="Input FITS file list")
parser.add_option("-s", "--sigmaclip", action="store", nargs=1, type="float", dest="sigmaclip", help="Sigma clipping level")
parser.add_option("-v", "--verbose", action="store_true", dest="verbose", help="Fully describe operations")
parser.add_option("-x", "--expmaplist", action="store", nargs=1, type="string", dest="expmaplist", help="Input FITS exposure map file list")
parser.add_option("-o", "--outfile", action="store", nargs=1, type="string", dest="outfitsfile", help="Output FITS file")
(options, args) = parser.parse_args()


if options.fitsfilelist and options.outfitsfile:
    sname = SRPFiles.getSRPSessionName()
    if options.verbose:
        print "Session name %s retrieved." % sname
    if os.path.isfile(options.fitsfilelist):
        f = SRPFiles.SRPFile(SRPConstants.SRPLocalDir,options.fitsfilelist,SRPFiles.ReadMode)
        f.SRPOpenFile()
        if options.verbose:
            print "Input FITS file list is: %s." % options.fitsfilelist
    #
        if options.expmaplist:
            if os.path.isfile(options.expmaplist):
                xf = SRPFiles.SRPFile(SRPConstants.SRPLocalDir,options.expmaplist,SRPFiles.ReadMode)
                xf.SRPOpenFile()
                if options.verbose:
                    print "Exposure map file list is: %s." % options.expmaplist
    #
        flist = []
        nentr = 0
    #
        if options.expmaplist:
            xflist = []
            xnentr = 0
    #
        while True:
            dt = f.SRPReadFile()
#
            if options.expmaplist:
                xdt = xf.SRPReadFile()
#
            if dt != '':
                flist.append(string.split(string.strip(dt))[0])
                nentr = nentr + 1
                if not os.path.isfile(flist[nentr-1]):
                    parser.error("Input FITS file %s not found" % flist[nentr-1])
                if options.verbose:
                    print "FITS file selected: %s" % string.split(string.strip(dt))[0]
            else:
                break
#
            if options.expmaplist:
                if xdt != '':
                    xflist.append(string.split(string.strip(xdt))[0])
                    xnentr = xnentr + 1
                    if not os.path.isfile(xflist[xnentr-1]):
                        parser.error("Input exposure map file %s not found" % xflist[xnentr-1])
                    if options.verbose:
                        print "Exposure map file selected: %s" % string.split(string.strip(xdt))[0]
                else:
                    break
    #
        f.SRPCloseFile()
    #
        if options.expmaplist:
            xf.SRPCloseFile()
    #
        if options.expmaplist:
            if nentr != xnentr:
                parser.error("FITS file list and exposure maps do not correspond.")
    #
        if options.sigmaclip:
            if options.sigmaclip <= 0.0:
                parser.error("Sigma clipping parameter must be positive.")
    #
        if options.verbose:
            print "Computing average..."
        tdata = []
    #
        if options.expmaplist:
            xtdata = []
    #
        thead = []
        shapex = 1e6
        shapey = 1e6
        etime = []
        obstm = []
        for i in flist:
            hdr = pyfits.open(i)
            tdata.append(hdr[0].data)
            thead.append(hdr[0].header)
            try:
                etime.append(hdr[0].header['EXPTIME'])
            except KeyError:
                etime.append(1.0)
            try:
                obstm.append(hdr[0].header['MJD-OBS'])
            except KeyError:
                obstm.append(0.0)
            if type(obstm[-1]) != types.FloatType:
                obstm[-1] = 0.0
            #
            shape = hdr[0].data.shape
            if shape[0] < shapey:
                shapey = shape[0]
            if shape[1] < shapex:
                shapex = shape[1]
    #
        if options.expmaplist:
            for i in xflist:
                xhdr = pyfits.open(i)
                xtdata.append(xhdr[0].data)
    #
#               print shapex, shapey
        newdata = numpy.zeros((shapey,shapex))
    #
        if options.expmaplist:
            xnewdata = numpy.zeros((shapey,shapex))
    #
        tard = numpy.array([tdata[i][:shapey,:shapex] for i in range(len(tdata))])
    #
        if options.expmaplist:
            xtard = numpy.array([xtdata[i][:shapey,:shapex] for i in range(len(xtdata))])
    #
        tottime = stats.sum(etime)
    #
        if options.sigmaclip:
            for l in range(shapex):
                if options.verbose:
                    pcr = 100.0*l/shapex
                    if pcr % 10 < 0.1:
                        print "Job completed: %.1f%%" % (100.0*l/shapex)
                for m in range(shapey):
                    pixm = AverIterSigmaClipp([tard[i,m,l] for i in range(len(tdata))],options.sigmaclip)[0]
                    newdata[m,l] = pixm
        else:
            for i in range(len(tdata)):
                tempdata = numpy.multiply(tdata[i][:shapey,:shapex],etime[i]/tottime)
                newdata = numpy.add(newdata,tempdata)
#
            if options.expmaplist:
                for i in range(len(xtdata)):
                    xtempdata = numpy.multiply(xtdata[i][:shapey,:shapex],etime[i]/tottime)
                    xnewdata = numpy.add(xnewdata,xtempdata)
#
            if options.expmaplist:
                newdata = numpy.divide(newdata,xnewdata)
            else:
                newdata = numpy.multiply(newdata,1.0)
        if options.verbose:
            print "Saving average file: %s" % sname+options.outfitsfile
    #
        if options.expmaplist and options.verbose:
            frot,frxt = os.path.splitext(options.outfitsfile)
            print "Saving average exposure map: %s" % sname+frot+SRPConstants.SRPExpMap+frxt
    #
        nfts = pyfits.PrimaryHDU(newdata,thead[0])
        nftlist = pyfits.HDUList([nfts])
        nftlist[0].header.update('hierarch '+SRPConstants.SRPCategory,SRPConstants.SRPSCIENCE,SRPConstants.SRPCatComm)
        nftlist[0].header.update('hierarch '+SRPConstants.SRPNFiles,len(tdata),SRPConstants.SRPNFilesComm)
        if options.sigmaclip:
            nftlist[0].header.update('hierarch '+SRPConstants.SRPMethod,SRPConstants.SRPAVERAGESC % options.sigmaclip,SRPConstants.SRPMethodComm)
        else:
            nftlist[0].header.update('hierarch '+SRPConstants.SRPMethod,SRPConstants.SRPAVERAGE,SRPConstants.SRPMethodComm)
        nftlist[0].header.update('EXPTIME',stats.mean(etime))
#               tottime = stats.sum(etime)
        if options.verbose:
            print "Total observing time: %.1fs" % tottime
        for i in range(len(obstm)):
            obstm[i] = obstm[i]*etime[i]/tottime
        nftlist[0].header.update('MJD-OBS',stats.sum(obstm))
        nftlist.writeto(sname+options.outfitsfile,clobber=True)
        if options.expmaplist:
            xnfts = pyfits.PrimaryHDU(xnewdata,thead[0])
            xnftlist = pyfits.HDUList([xnfts])
            xnftlist.writeto(sname+frot+SRPConstants.SRPExpMap+frxt,clobber=True)
    else:
        parser.error("Input FITS file list %s not found" % options.fitsfilelist)
else:
    parser.print_help()
