#!/usr/bin/env python

# Copyright (C) 2011 Ian W. Harry
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Stochastic aligned spin bank generator.
"""
from __future__ import division
import warnings

def _warning(
    message,
    category = UserWarning,
    filename = '',
    lineno = -1):
    print(message)

# Route every displayed warning through _warning, so that only the message
# text is printed.
warnings.showwarning = _warning

# Issue the deprecation notice as soon as the script starts.
# NOTE(review): whether it is actually displayed depends on the warning
# filters active for DeprecationWarning -- confirm it is shown to users.
warnings.warn("""DEPRECATION WARNING:
THE FOLLOWING CODE HAS BEEN DEPRECATED AND MOVED INTO THE PYCBC PACKAGE. IT WILL BE REMOVED FROM PYLAL AT SOME POINT IN THE NEAR FUTURE.

FOR NOW PLEASE CONSULT:

https://ldas-jobs.ligo.caltech.edu/~cbc/docs/pycbc/tmpltbank.html

FOR INSTRUCTIONS OF HOW TO USE THE REPLACEMENTS IN PYCBC. PYCBC BUILD INSTRUCTIONS ARE HERE

https://ldas-jobs.ligo.caltech.edu/~cbc/docs/pycbc/index.html""", DeprecationWarning)

import matplotlib
matplotlib.use('Agg')
import pylab
import time
# Reference instant for the progress messages, in whole microseconds.
start = int(time.time() * 10**6)


def elapsed_time():
    """Return the whole number of microseconds elapsed since ``start``."""
    return int(time.time() * 10**6 - start)

import os,sys,optparse,copy
import tempfile
import ConfigParser
import numpy
from pylal import geom_aligned_bank_utils,git_version
from glue import pipeline
from lal import PI as LAL_PI
from lal import MTSUN_SI as LAL_MTSUN_SI

# Module metadata; the version string and date are taken from the
# git_version module imported from pylal.
__author__  = "Ian Harry <ian.harry@astro.cf.ac.uk>"
__version__ = "git id %s" % git_version.id
__date__    = git_version.date

# Read command line options
usage = """usage: %prog [options]"""
# Drop the first character of the module docstring for the description text.
_desc = __doc__[1:]
parser = optparse.OptionParser(usage, version=__version__, description=_desc)

# General and metric-related options.
parser.add_option("-v", "--verbose", action="store_true", default=False,
                  help="verbose output, default: %default")
parser.add_option("-a", "--psd-file", action="store", type="string",
                  default=None,
                  help="ASCII file containing the PSD.")
parser.add_option("-o", "--pn-order", action="store", type="string",
                  default=None,
                  help="""Determines the PN order to use, choices are:
    * "twoPN": will include spin and non-spin terms up to 2PN in phase
    * "threePointFivePN": will include non-spin terms to 3.5PN, spin to 2.5PN
    * "taylorF4_45PN": use the R2D2 metric with partial terms to 4.5PN""")
parser.add_option("-f", "--f0", action="store", type="float",
                  default=70.,
                  help="f0 for use in metric calculation, "
                       "default: %default")
parser.add_option("-l", "--f-low", action="store", type="float",
                  default=15.,
                  help="f_low for use in metric calculation, "
                       "default: %default")
parser.add_option("-u", "--f-upper", action="store", type="float",
                  default=2000.,
                  help="f_up for use in metric calculation, "
                       "default: %default")
parser.add_option("-d", "--delta-f", action="store", type="float",
                  default=0.001,
                  help="delta_f for use in metric calculation, "
                       "linear interpolation used to get this, "
                       "default: %default")

# Bank extent options.
# NOTE(review): further down the script uses min_match**0.5 directly as the
# distance threshold between templates, so this value seems to act as a
# maximal mismatch rather than a match -- confirm the intended wording.
parser.add_option("-m", "--min-match", action="store", type="float",
                  default=0.03,
                  help="Minimum match to generate bank with, "
                       "default: %default")
parser.add_option("-y", "--min-mass1", action="store", type="float",
                  default=0.03,
                  help="Minimum mass1 to generate bank with, "
                       "mass1 *must* be larger than mass2, "
                       "default: %default")
parser.add_option("-Y", "--max-mass1", action="store", type="float",
                  default=0.03,
                  help="Maximum mass1 to generate bank with, "
                       "default: %default")
parser.add_option("-z", "--min-mass2", action="store", type="float",
                  default=0.03,
                  help="Minimum mass2 to generate bank with, "
                       "default: %default")
parser.add_option("-Z", "--max-mass2", action="store", type="float",
                  default=0.03,
                  help="Maximum mass2 to generate bank with, "
                       "default: %default")
parser.add_option("-W", "--max-total-mass", action="store", type="float",
                  default=None,
                  help="Set a maximum total mass, "
                       "default: %default")
parser.add_option("-x", "--max-ns-spin-mag", action="store", type="float",
                  default=0.03,
                  help="Maximum neutron star spin magnitude, "
                       "default: %default")
parser.add_option("-X", "--max-bh-spin-mag", action="store", type="float",
                  default=0.03,
                  help="Maximum black hole spin magnitude, "
                       "default: %default")

# Run-mode options.
parser.add_option("-n", "--nsbh-flag", action="store_true", default=False,
                  help="Set this if running with NSBH, default: %default")
parser.add_option("-c", "--covary", action="store_true", default=False,
                  help="Set this to run with covaried coordinates. This will "
                       "provide a speed up, but cannot run with varying "
                       "f-lower yet. default: %default")
parser.add_option("-V", "--vary-fupper", action="store_true", default=False,
                  help="Set this to run the code using a variable f_upper. "
                       "default: %default")
parser.add_option("-N", "--num-seeds", action="store", type="int",
                  default=100000000,
                  help="Number of seed points to make bank, "
                       "default: %default")


(opts,args) = parser.parse_args()

# Settings without a command line option. The failure cutoff can be used to
# replicate behaviour of sBANK.
opts.num_failed_cutoff = 10000000
opts.split_bank_num = 100

# With a variable upper frequency the upper cutoff is always reset to 2000.
if opts.vary_fupper:
    opts.f_upper = 2000

# Component-mass limits are copied from the mass2 limits.
opts.min_comp_mass = opts.min_mass2
opts.max_comp_mass = opts.max_mass2

# Total-mass limits; the upper one falls back to the sum of the component
# maxima when no (or a zero) maximum total mass was given.
opts.min_total_mass = opts.min_mass1 + opts.min_mass2
opts.max_total_mass = opts.max_total_mass or (opts.max_mass1 + opts.max_mass2)

# This could be altered to do an exact match if desired
def dist(vsA, entryA, MMdistA):
    """Tell whether two points are closer together than a threshold.

    The separation is the Euclidean distance taken over the first
    ``len(vsA)`` coordinates of ``vsA`` and ``entryA``.

    Returns a true value when that distance is strictly below ``MMdistA``.
    """
    n_dims = len(vsA)
    sq_sep = (vsA[0] - entryA[0]) ** 2
    for k in range(1, n_dims):
        sq_sep = sq_sep + (vsA[k] - entryA[k]) ** 2
    separation = numpy.sqrt(sq_sep)
    return separation < MMdistA

# If varying f_upper we want to test the metric distances properly
def dist_vary(mus1, fUpper1, masses2, fMap, MMdistA):
    """Compare a candidate with a stored entry at both upper frequencies.

    ``mus1`` holds the candidate's coordinate vectors, one per row index
    produced by ``fMap``. ``masses2`` is the stored entry: its item 4 is the
    entry's own upper frequency and its item 5 its coordinate vectors, laid
    out like ``mus1``. ``fMap`` maps an upper frequency to a row index.

    The Euclidean distance is evaluated first at the row for ``fUpper1``;
    if it exceeds ``MMdistA`` the result is ``False``. Otherwise the result
    says whether the distance at the row for the entry's own upper frequency
    is strictly below ``MMdistA``.
    """
    def _separation(row):
        # Euclidean distance between the two coordinate vectors of one row.
        ours = mus1[row]
        theirs = stored_mus[row]
        total = (ours[0] - theirs[0]) ** 2
        for k in range(1, len(ours)):
            total += (ours[k] - theirs[k]) ** 2
        return numpy.sqrt(total)

    stored_mus = masses2[5]
    if _separation(fMap[fUpper1]) > MMdistA:
        return False
    return _separation(fMap[masses2[4]]) < MMdistA

def return_nearest_isco(totmass,freqs):
    """Assign each total mass the tabulated frequency nearest to its fISCO.

    ``fISCO`` is computed for every element of ``totmass`` as
    ``6**(-3/2) / (LAL_PI * totmass * LAL_MTSUN_SI)``.

    Parameters
    ----------
    totmass : numpy array of total masses (one fISCO value per element).
        Presumably in solar masses -- confirm against the callers.
    freqs : sequence of candidate frequencies. The midpoint logic below
        assumes it is sorted in increasing order.

    Returns
    -------
    numpy array of floats with one entry per element of ``totmass``, each
    set to the element of ``freqs`` whose midpoint interval contains the
    corresponding fISCO. With an empty ``freqs`` every entry stays 0.
    """
    fISCO = (1/6.)**(3./2.) / (LAL_PI * totmass * LAL_MTSUN_SI)
    refEv = numpy.zeros(len(fISCO),dtype=float)
    nFreqs = len(freqs)
    for i in range(nFreqs):
        if nFreqs == 1:
            # A single candidate is the nearest one for every mass.
            logicArr = numpy.ones(len(fISCO),dtype=bool)
        elif i == 0:
            # Everything below the first midpoint maps to the lowest entry.
            logicArr = fISCO < ((freqs[0] + freqs[1])/2.)
        elif i == (nFreqs-1):
            # Everything from the last midpoint upwards maps to the highest.
            logicArr = fISCO >= ((freqs[-2] + freqs[-1])/2.)
        else:
            # Interior entries own the half-open interval between the two
            # neighbouring midpoints, so a value exactly on a midpoint is
            # still assigned (to the higher frequency).
            logicArrA = fISCO >= ((freqs[i-1] + freqs[i])/2.)
            logicArrB = fISCO < ((freqs[i] + freqs[i+1])/2.)
            logicArr = numpy.logical_and(logicArrA,logicArrB)
        if logicArr.any():
            refEv[logicArr] = freqs[i]
    return refEv

# Begin by calculating a metric
evals,evecs = geom_aligned_bank_utils.determine_eigen_directions(opts.psd_file,\
    opts.pn_order,opts.f0,opts.f_low,opts.f_upper,opts.delta_f,\
    verbose=opts.verbose,elapsed_time=elapsed_time,vary_fmax=opts.vary_fupper)

if not opts.vary_fupper:
  evalsSngl = evals['fixed']
  evecsSngl = evecs['fixed']
else:
  fs = numpy.array(evals.keys(),dtype=float)
  fs.sort()
  maxTMass = opts.max_total_mass
  minTMass = opts.min_total_mass
  lowEve = return_nearest_isco(numpy.array([maxTMass]),fs)[0]
  highEve = return_nearest_isco(numpy.array([minTMass]),fs)[0]
  print lowEve,highEve
  evalsSngl = evals[lowEve]
  evecsSngl = evecs[lowEve]

# Optionally derive a second rotation from the covariance of sampled points.
if not opts.covary:
    evecsCV = None
else:
    if opts.verbose:
        print >>sys.stdout, "Calculating covariance matrix at %d." %(elapsed_time())

    # Sample the parameter space in the non-covaried coordinates.
    vals = geom_aligned_bank_utils.estimate_mass_range_slimline(
        1000000, opts.pn_order, evalsSngl, evecsSngl, opts.max_mass1,
        opts.min_mass1, opts.max_mass2, opts.min_mass2,
        opts.max_ns_spin_mag, opts.f0, maxmass=opts.max_total_mass,
        covary=False, maxBHspin=opts.max_bh_spin_mag)
    cov = numpy.cov(vals)
    evalsCV, evecsCV = numpy.linalg.eig(cov)

    if opts.verbose:
        print >>sys.stdout, "Covariance matrix calculated at %d." %(elapsed_time())

if opts.verbose:
  print>> sys.stdout, "Determining parameter space extent %d." %(elapsed_time())

vals = geom_aligned_bank_utils.estimate_mass_range_slimline(1000000,\
       opts.pn_order,evalsSngl,evecsSngl,opts.max_mass1,\
       opts.min_mass1,opts.max_mass2,opts.min_mass2,\
       opts.max_ns_spin_mag,opts.f0,maxmass=opts.max_total_mass,\
       covary=opts.covary,evecsCV=evecsCV,maxBHspin=opts.max_bh_spin_mag)

pylab.plot(vals[0],vals[1],'b.')
pylab.savefig('testing.png')

chi1Max = vals[0].max()
chi1Min = vals[0].min()
chi1Diff = chi1Max - chi1Min
chi2Max = vals[1].max()
chi2Min = vals[1].min()
chi2Diff = chi2Max - chi2Min
chi1Min = chi1Min - 0.1*chi1Diff
chi1Max = chi1Max + 0.1*chi1Diff
chi2Min = chi2Min - 0.1*chi2Diff
chi2Max = chi2Max + 0.1*chi2Diff

if opts.verbose:
  print>> sys.stdout, "Determined parameter space extent %d." %(elapsed_time())
  print chi1Min,chi1Max,chi2Min,chi2Max


if opts.verbose:
  print>> sys.stdout, "Initializing bank set up at %d." %(elapsed_time())

vals = None

# Set up the bank into sections
massbank = {}
bank = {}
MMdist = (opts.min_match)**0.5
for i in range(int((chi1Max - chi1Min) // MMdist)):
  bank[i] = {}
  massbank[i] = {}
  for j in range(int((chi2Max - chi2Min) // MMdist)):
    bank[i][j] = []
    massbank[i][j] = []

maxi = int((chi1Max - chi1Min) // MMdist)
maxj = int((chi2Max - chi2Min) // MMdist)
# Initialise counters
N = 0
Np = 0
Ns = 0
Nr = 0

if opts.vary_fupper:
  freqMap = {}
  idx = 0
  for freq in fs:
    if freq >= lowEve and freq <= highEve:
      freqMap[freq] = idx
      idx += 1
      print freq

if opts.verbose:
  print>> sys.stdout, "Initialized bank and starting at %d." %(elapsed_time())

# Begin making the thing
outbins = [[0,0],[0,1],[1,0],[0,-1],[-1,0],[1,1],[1,-1],[-1,1],[-1,-1]]
while(1):
  if not (Ns % 100000):
    rTotmass,rEta,rBeta,rSigma,rGamma,rSpin1z,rSpin2z =\
        geom_aligned_bank_utils.get_random_mass_slimline(\
        100000,opts.min_mass1,opts.max_mass1,opts.min_mass2,opts.max_mass2,\
        opts.max_ns_spin_mag,maxBHspin = opts.max_bh_spin_mag,\
        return_spins=True,maxmass=opts.max_total_mass)
    diff = (rTotmass*rTotmass * (1-4*rEta))**0.5
    rMass1 = (rTotmass + diff)/2.
    rMass2 = (rTotmass - diff)/2.
    rChis = (rSpin1z + rSpin2z)/2.
    if opts.vary_fupper:
      refEve = return_nearest_isco(rTotmass,fs)
      lambdas =  geom_aligned_bank_utils.get_chirp_params(rTotmass,rEta,rBeta,\
          rSigma,rGamma,rChis,opts.f0,opts.pn_order)
#      lambdas = numpy.array(lambdas)
#      lambdas = lambdas.T
      mus = []
      idx = 0
      for freq in fs:
        if freq >= lowEve and freq <= highEve: 
          mus.append(geom_aligned_bank_utils.get_chi_params(lambdas,opts.f0,evecs[freq],evals[freq],opts.pn_order))
          if freqMap[freq] != idx:
            raise BrokenError
          idx += 1
      mus = numpy.array(mus)
    else:
      refEve = numpy.zeros(100000)
      mus = numpy.zeros([1,1,100000])
    if opts.covary:
      vecs = geom_aligned_bank_utils.get_cov_params(rTotmass,rEta,rBeta,rSigma,\
        rGamma,rChis,opts.f0,evecsSngl,evalsSngl,evecsCV,\
        opts.pn_order)
    else:
      vecs = geom_aligned_bank_utils.get_conv_params(rTotmass,rEta,rBeta,\
        rSigma,rGamma,rChis,opts.f0,evecsSngl,evalsSngl,opts.pn_order)
      
    vecs = numpy.array(vecs)
    Ns = 0
  if not (Np % 10000) and opts.verbose:
    print "Seeds",Np
  vs = vecs[:,Ns]
  v1Bin = int((vs[0] - chi1Min) // MMdist)
  v2Bin = int((vs[1] - chi2Min) // MMdist)
  store = True
  Np = Np + 1
  if opts.vary_fupper:
    for i,j in outbins:
      if store:
        for entry in massbank[v1Bin+i][v2Bin+j]:
          if dist_vary(mus[:,:,Ns],refEve[Ns],entry,freqMap,MMdist):
            store = False
            break
  else:
    for i,j in outbins:
      if store:
        for entry in bank[v1Bin+i][v2Bin+j]:
          if dist(vs,entry,MMdist):
            store = False
            break
  if not store:
    Ns = Ns + 1
    Nr = Nr + 1
    if Nr > opts.num_failed_cutoff:
      break
    continue
  Nr = 0
  bank[v1Bin][v2Bin].append([copy.deepcopy(vs[0]),copy.deepcopy(vs[1]),copy.deepcopy(vs[2]),copy.deepcopy(vs[3]),copy.deepcopy(vs[4]),copy.deepcopy(vs[5]),copy.deepcopy(vs[6]),copy.deepcopy(vs[7])])
  massbank[v1Bin][v2Bin].append([copy.deepcopy(rMass1[Ns]),copy.deepcopy(rMass2[Ns]),copy.deepcopy(rSpin1z[Ns]),copy.deepcopy(rSpin2z[Ns]),copy.deepcopy(refEve[Ns]),copy.deepcopy(mus[:,:,Ns])])
  N = N + 1
  if opts.verbose and not (N % 100000):
    print "Templates %d at %d" %(N,elapsed_time())
  if Np > opts.num_seeds:
    break
  Ns = Ns + 1  

if opts.verbose:
  print "Outputting at %d." %(elapsed_time())

outfile=open('stochastic_bank_mass.dat','w')
outfile2=open('stochastic_bank_evs.dat','w')

for i in range(int((chi1Max - chi1Min) // MMdist)):
  for j in range(int((chi2Max - chi2Min) // MMdist)):
    for entry in bank[i][j]:
      outfile2.write('%.16e %.16e %.16e %.16e %.16e %.16e %.16e %.16e\n' %(entry[0],entry[1],entry[2],entry[3],entry[4],entry[5],entry[6],entry[7]))
outfile2.close()

for i in range(int((chi1Max - chi1Min) // MMdist)):
  for j in range(int((chi2Max - chi2Min) // MMdist)):
    for masses in massbank[i][j]:
      outfile.write('%.16e %.16e %.16e %.16e\n' %(masses[0],masses[1],masses[2],masses[3]))
outfile.close()

  
