#!/usr/bin/env python

# Copyright (C) 2011 Ian W. Harry
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Aligned spin bank sub-program.
"""

from __future__ import division
import matplotlib
matplotlib.use('Agg')
import pylab
import time
# Wall-clock time at start-up, in whole microseconds since the epoch.
start = int(time.time() * 1000000)


def elapsed_time():
  """Return the whole number of microseconds elapsed since ``start``."""
  return int(time.time() * 1000000 - start)

import os,sys,optparse,copy
import numpy
from pylal import geom_aligned_bank_utils,git_version

# Module metadata; the version string and date are read from git_version.
__author__ = "Ian Harry <ian.harry@astro.cf.ac.uk>"
__date__ = git_version.date
__version__ = "git id %s" % (git_version.id,)


# Feed in command line options
usage = """usage: %prog [options]"""
# Drop the leading newline of the module docstring. Fall back to an empty
# description when docstrings have been stripped (python -OO), in which case
# __doc__ is None and slicing it would raise.
_desc = (__doc__ or "")[1:]
parser = optparse.OptionParser(usage, version=__version__, description=_desc)

# NOTE(review): every mass, spin and match default below is 0.03, which looks
# like a placeholder rather than a meaningful value -- confirm. The defaults
# are left untouched here so existing invocations behave as before.
parser.add_option("-v", "--verbose", action="store_true", default=False,
                  help="verbose output, default: %default")

# Input / output files
parser.add_option("-b", "--input-bank-file", action="store", type="string",
                  default=None,
                  help="ASCII file containing the input bank.")
parser.add_option("-B", "--output-bank-file", action="store", type="string",
                  default=None,
                  help="ASCII file containing the output bank.")
parser.add_option("-o", "--pn-order", action="store", type="string",
                  default=None,
                  help="""Determines the PN order to use, choices are:
    * "twoPN": will include spin and non-spin terms up to 2PN in phase
    * "threePointFivePN": will include non-spin terms to 3.5PN, spin to 2.5PN
    * "taylorF4_45PN": use the R2D2 metric with partial terms to 4.5PN""")
parser.add_option("-m", "--metric-evals-file", action="store", type="string",
                  default=None,
                  help="ASCII file containing the metric eigenvalues.")
parser.add_option("-M", "--metric-evecs-file", action="store", type="string",
                  default=None,
                  help="ASCII file containing the metric eigenvectors.")
parser.add_option("-N", "--cov-evecs-file", action="store", type="string",
                  default=None,
                  help="ASCII file containing the covariance eigenvectors.")

# Metric and parameter-space limits
parser.add_option("-f", "--f0", action="store", type="float",
                  default=70.,
                  help="f0 for use in metric calculation, default: %default")
parser.add_option("-y", "--min-mass1", action="store", type="float",
                  default=0.03,
                  help="Minimum mass1 to generate bank with, mass1 *must* "
                       "be larger than mass2, default: %default")
parser.add_option("-Y", "--max-mass1", action="store", type="float",
                  default=0.03,
                  help="Maximum mass1 to generate bank with, "
                       "default: %default")
parser.add_option("-z", "--min-mass2", action="store", type="float",
                  default=0.03,
                  help="Minimum mass2 to generate bank with, "
                       "default: %default")
parser.add_option("-Z", "--max-mass2", action="store", type="float",
                  default=0.03,
                  help="Maximum mass2 to generate bank with, "
                       "default: %default")
parser.add_option("-x", "--max-ns-spin-mag", action="store", type="float",
                  default=0.03,
                  help="Maximum neutron star spin magnitude, "
                       "default: %default")
parser.add_option("-X", "--max-bh-spin-mag", action="store", type="float",
                  default=0.03,
                  help="Maximum black hole spin magnitude, "
                       "default: %default")
parser.add_option("-O", "--min-match", action="store", type="float",
                  default=0.03,
                  help="Minimum match to generate bank with, "
                       "default: %default")
parser.add_option("-n", "--nsbh-flag", action="store_true", default=False,
                  help="Set this if running with NSBH, default: %default")

# Stacking behaviour
parser.add_option("-s", "--stack-distance", action="store", type="float",
                  default=0.2,
                  help="Minimum metric spacing before we stack, "
                       "default: %default")
parser.add_option("-3", "--threed-lattice", action="store_true", default=False,
                  help="Set this to use a 3D lattice, default: %default")
parser.add_option("-a", "--skip-vec4-depth", action="store_true",
                  default=False,
                  help="Set this to skip evaluating the depth in the xi4 "
                       "direction, default: %default")
parser.add_option("-F", "--skip-vec5-depth", action="store_true",
                  default=False,
                  help="Set this to skip evaluating the depth in the xi5 "
                       "direction, default: %default")



(opts,args) = parser.parse_args()

# These files are opened unconditionally further down; report a missing one
# as a usage error instead of letting loadtxt()/open() fail on None.
for _flag, _value in [("--input-bank-file", opts.input_bank_file),
                      ("--output-bank-file", opts.output_bank_file),
                      ("--metric-evals-file", opts.metric_evals_file),
                      ("--metric-evecs-file", opts.metric_evecs_file),
                      ("--cov-evecs-file", opts.cov_evecs_file)]:
  if _value is None:
    parser.error("%s must be supplied" % _flag)

# The command line exposes "skip" switches; the rest of the script tests the
# positive form.
opts.eval_vec4_depth = not opts.skip_vec4_depth
opts.eval_vec5_depth = not opts.skip_vec5_depth

# Lattice positions to process, one row per template. The first two columns
# (three with --threed-lattice) are read as the desired xi coordinates.
tempBank = numpy.loadtxt(opts.input_bank_file)
if tempBank.ndim == 1 and tempBank.size:
  # loadtxt returns a 1-D array for a file with a single row; restore the
  # row axis so that iterating over the bank yields one row, not scalars.
  tempBank = tempBank.reshape(1, -1)

# Metric eigenvalues/eigenvectors and covariance eigenvectors, passed through
# unchanged to the geom_aligned_bank_utils routines.
evals = numpy.loadtxt(opts.metric_evals_file)
evecs = numpy.matrix(numpy.loadtxt(opts.metric_evecs_file))
evecsCV = numpy.matrix(numpy.loadtxt(opts.cov_evecs_file))

# Cloud of physical points inside the requested mass/spin limits. The first
# argument (2000000) is presumably the number of points drawn -- confirm
# against get_random_mass_slimline.
rTotmass,rEta,rBeta,rSigma,rGamma,rSpin1z,rSpin2z = \
    geom_aligned_bank_utils.get_random_mass_slimline(
        2000000, opts.min_mass1, opts.max_mass1,
        opts.min_mass2, opts.max_mass2, opts.max_ns_spin_mag,
        maxBHspin=opts.max_bh_spin_mag, return_spins=True)

# Component masses recovered from total mass and eta:
#   m1, m2 = (M +/- M * sqrt(1 - 4 * eta)) / 2
diff = (rTotmass*rTotmass * (1-4*rEta))**0.5
rMass1 = (rTotmass + diff)/2.
rMass2 = (rTotmass - diff)/2.
rChis = (rSpin1z + rSpin2z)/2.

# xi coordinates of every point in the cloud
rXis = geom_aligned_bank_utils.get_cov_params(
    rTotmass, rEta, rBeta, rSigma, rGamma, rChis,
    opts.f0, evecs, evals, evecsCV, opts.pn_order)

# Row i of xis and row i of physMasses describe the same point:
#   xis[i]        -> its xi coordinates
#   physMasses[i] -> (total mass, eta, spin1z, spin2z)
xis = (numpy.array(rXis)).T
physMasses = numpy.array([rTotmass,rEta,rSpin1z,rSpin2z])
physMasses = physMasses.T

# Short aliases used by the template loop
f0 = opts.f0
order = opts.pn_order
maxmass1 = opts.max_mass1
maxmass2 = opts.max_mass2
minmass1 = opts.min_mass1
minmass2 = opts.min_mass2
maxNSspin = opts.max_ns_spin_mag
maxBHspin = opts.max_bh_spin_mag

# Here we would start looping over bank
#
# For every lattice position in tempBank:
#   1. find the closest point of the random cloud (xis / physMasses) and
#      refine it with get_physical_covaried_masses; positions that stay too
#      far away are logged to <output>.reject and skipped,
#   2. measure the extent of the xi3 (2D lattice only), xi4 and xi5
#      directions with stack_xi_direction_brute and log them to
#      <output>.depths,
#   3. stack templates along xi3, and along xi4 where it is deep enough,
#      writing each accepted template to <output>.

# Tolerance handed to get_physical_covaried_masses as its req_match argument.
req_match = 0.0001
# Spacing used to decide how many templates to stack along xi3 and xi4.
vec3DepthVal = opts.stack_distance


def _write_row(handle, values):
  """Write ``values`` to ``handle`` as one space-separated line of %e fields."""
  handle.write(" ".join(["%e" % (value,) for value in values]) + "\n")


def _nearest_physical_point(target_xis, seed_masses, seed_xis, temp_number):
  """Call get_physical_covaried_masses with the run-wide settings filled in.

  The seeds are deep-copied so the callee cannot modify rows of the shared
  random-point arrays. This script reads the result as: [0]-[3] the values
  written to the output bank, [5] the remaining distance, [6] onwards the xi
  coordinates that were reached.
  """
  return geom_aligned_bank_utils.get_physical_covaried_masses(
      target_xis, copy.deepcopy(seed_masses), copy.deepcopy(seed_xis),
      f0, temp_number, req_match, order, evecs, evals, evecsCV,
      maxmass1, minmass1, maxmass2, minmass2, maxNSspin, maxBHspin,
      return_smaller=True, nsbh_flag=opts.nsbh_flag)


def _direction_extent(target_xis, seed_masses, seed_xis, temp_number,
                      direction):
  """Call stack_xi_direction_brute with the run-wide settings filled in.

  ``direction`` is the zero-based xi index (2 -> xi3, 3 -> xi4, 4 -> xi5).
  Returns the (min, max, depth) triple produced by the callee.
  """
  return geom_aligned_bank_utils.stack_xi_direction_brute(
      target_xis, copy.deepcopy(seed_masses), copy.deepcopy(seed_xis),
      f0, temp_number, direction, opts.min_match, order, evecs, evals,
      evecsCV, maxmass1, minmass1, maxmass2, minmass2, maxNSspin, maxBHspin,
      nsbh_flag=opts.nsbh_flag)


def _widen_extent(current, retry):
  """Grow the (min, max, depth) triple ``current`` to also cover ``retry``.

  The depth is recomputed as max - min only when a bound actually moves;
  otherwise the depth reported in ``current`` is kept as is.
  """
  low, high, depth = current
  if retry[0] < low:
    low = retry[0]
    depth = high - low
  if retry[1] > high:
    high = retry[1]
    depth = high - low
  return low, high, depth


temp_number = 0
outFile = opts.output_bank_file
outPointer = open(outFile,'w')
outPointer2 = open(outFile+'.reject','w')
outPointer3 = open(outFile+'.depths','w')
try:
  for entry in tempBank:
    temp_number += 1
    if opts.verbose:
      print("Analysing template %d" %(temp_number))
    xi1_des = entry[0]
    xi2_des = entry[1]
    xis_des = [xi1_des,xi2_des]
    if opts.threed_lattice:
      xi3_des = entry[2]
      xis_des.append(xi3_des)

    # Squared coordinate distance from the lattice position to every point
    # of the random cloud.
    dist = (xi1_des - xis[:,0])**2 + (xi2_des - xis[:,1])**2
    if opts.threed_lattice:
      dist += (xi3_des - xis[:,2])**2
    close_mask = dist < 0.03
    xis_close = xis[close_mask]
    masses_close = physMasses[close_mask]
    nearest_idx = dist.argmin()
    bestMasses = physMasses[nearest_idx]
    bestXis = xis[nearest_idx]
    min_dist = dist.min()
    if opts.verbose:
      print("Template %d has initial distance of %e" %(temp_number,min_dist))

    # First find a point within MM of the desired xi position
    if min_dist > 2.:
      # Reject point, it is too far away
      if opts.verbose:
        print("Template %d rejected as too far away" %(temp_number))
      _write_row(outPointer2, xis_des + [0, 0, 0, 0, min_dist])
      continue
    masses = _nearest_physical_point(list(xis_des), bestMasses, bestXis,
                                     temp_number)
    if opts.verbose:
      print("Template %d has corrected distance of %e" \
            %(temp_number,masses[5]))
    if masses[5] > 0.05:
      # Reject point, it is too far away
      if opts.verbose:
        print("Template %d rejected as too far away" %(temp_number))
      _write_row(outPointer2, xis_des + [masses[0], masses[1], masses[2],
                                         masses[3], masses[5]])
      continue

    tmpTotMass = masses[0] + masses[1]
    tmpEta = masses[0] * masses[1] / (tmpTotMass*tmpTotMass)
    seed_point = [tmpTotMass,tmpEta,masses[2],masses[3]]

    if not opts.threed_lattice:
      # If point is close enough, determine depth of xi3 direction
      vec3_min,vec3_max,vec3_depth = _direction_extent(
          [masses[6],masses[7]], seed_point, bestXis, temp_number, 2)
      # Double check that there are no points with depth outside what was
      # above. If there are, rerun from the offending cloud point and widen.
      if len(xis_close):
        if vec3_min > xis_close[:,2].min():
          print("WARNING: Numerical placement fails, trying again")
          temp_idx = xis_close[:,2].argmin()
          retry = _direction_extent([xi1_des,xi2_des], masses_close[temp_idx],
                                    xis_close[temp_idx], temp_number, 2)
          vec3_min,vec3_max,vec3_depth = _widen_extent(
              (vec3_min,vec3_max,vec3_depth), retry)
        if vec3_max < xis_close[:,2].max():
          print("WARNING: Numerical placement fails, trying again")
          # temp_idx indexes the *close* subset, so the seed has to be taken
          # from masses_close/xis_close and not from the full arrays.
          temp_idx = xis_close[:,2].argmax()
          retry = _direction_extent([xi1_des,xi2_des], masses_close[temp_idx],
                                    xis_close[temp_idx], temp_number, 2)
          vec3_min,vec3_max,vec3_depth = _widen_extent(
              (vec3_min,vec3_max,vec3_depth), retry)

    # Determine depth of xi4 direction
    if opts.eval_vec4_depth:
      vec4_min,vec4_max,vec4_depth = _direction_extent(
          [masses[6],masses[7]], seed_point, bestXis, temp_number, 3)
    else:
      vec4_min = vec4_max = vec4_depth = 0
    # Determine depth of xi5 direction
    if opts.eval_vec5_depth:
      vec5_min,vec5_max,vec5_depth = _direction_extent(
          [masses[6],masses[7]], seed_point, bestXis, temp_number, 4)
    else:
      vec5_min = vec5_max = vec5_depth = 0

    # Output depths
    if opts.threed_lattice:
      _write_row(outPointer3,
                 [xi1_des,xi2_des,xi3_des,vec4_depth,vec5_depth])
    else:
      _write_row(outPointer3,
                 [xi1_des,xi2_des,vec3_depth,vec4_depth,vec5_depth])

    # Figure out how many templates we need to stack in 3rd direction
    if opts.threed_lattice:
      numV3Temps = 1
    else:
      numV3Temps = int(round(vec3_depth // vec3DepthVal)) + 1
    for ite in range(numV3Temps):
      if not opts.threed_lattice:
        # Centre of the ite-th of numV3Temps equal slices of the xi3 range
        xi3_des = vec3_min + (vec3_depth) * (2 * ite + 1) \
                  / (2. * (numV3Temps))
      dist = (xi1_des - xis[:,0])**2 + (xi2_des - xis[:,1])**2 \
             + (xi3_des - xis[:,2])**2
      nearest_idx = dist.argmin()
      bestMasses = physMasses[nearest_idx]
      bestXis = xis[nearest_idx]
      # Find close point to this 3d position
      masses = _nearest_physical_point([xi1_des,xi2_des,xi3_des],
                                       bestMasses, bestXis, temp_number)
      # If vec4 depth is negligible or we didn't get close, stop here
      if vec4_depth < vec3DepthVal or masses[5] > 0.03:
        if masses[5]:
          _write_row(outPointer, [masses[0], masses[1], masses[2],
                                  masses[3], masses[5]])
        continue

      # OR we need to estimate the depth in the 4th direction at this 3d
      # point
      tmpTotMass = masses[0] + masses[1]
      tmpEta = masses[0] * masses[1] / (tmpTotMass*tmpTotMass)
      vec4_minT,vec4_maxT,vec4_depthT = _direction_extent(
          [masses[6],masses[7],masses[8]],
          [tmpTotMass,tmpEta,masses[2],masses[3]],
          bestXis, temp_number, 3)
      # Then loop over necessary templates in 4th direction
      numV4Temps = int(round(vec4_depthT // vec3DepthVal)) + 1
      for ite4 in range(numV4Temps):
        xi4_des = vec4_minT + (vec4_depthT) * (2 * ite4 + 1) \
                  / (2. * (numV4Temps))
        dist = (xi1_des - xis[:,0])**2 + (xi2_des - xis[:,1])**2 \
               + (xi3_des - xis[:,2])**2 + (xi4_des - xis[:,3])**2
        nearest_idx = dist.argmin()
        bestMasses = physMasses[nearest_idx]
        bestXis = xis[nearest_idx]
        masses = _nearest_physical_point([xi1_des,xi2_des,xi3_des,xi4_des],
                                         bestMasses, bestXis, temp_number)
        if masses[5]:
          _write_row(outPointer, [masses[0], masses[1], masses[2],
                                  masses[3], masses[5]])
finally:
  # Close all three outputs, also when a template raises part-way through.
  outPointer.close()
  outPointer2.close()
  outPointer3.close()
