#! /usr/bin/env python
#
# Copyright (C) 2012 Stephen Privitera
# Copyright (C) 2011-2014 Chad Hanna
# Copyright (C) 2010 Melissa Frei
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#   
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


import itertools
import math
import os
import sys
from optparse import OptionParser
from pylal import spawaveform
from pylal import rate
from glue.ligolw import ligolw
from glue.ligolw import lsctables
from glue.ligolw import utils as ligolw_utils
from glue.ligolw.utils import process as ligolw_process
from lal import MSUN_SI
from glue import lal as gluelal
from pylal.datatypes import LIGOTimeGPS
from gstlal import templates
from gstlal import inspiral_pipe
from gstlal import chirptime

## @file gstlal_bank_splitter
#
# This program splits template banks into sub banks suitable for singular value decomposition; see gstlal_bank_splitter for more information
# 
# ### Usage examples
#
# - split up bank file for H1; sort by mchirp; specify a maximum final frequency, a lower frequency cutoff, an approximant and an output cache
#
#		$ gstlal_bank_splitter --overlap 10 --instrument H1 --n 100 --sort-by mchirp --max-f-final 2048 --f-low 40 --approximant 0:1000:TaylorF2 --output-cache H1_split_bank.cache H1-TMPLTBANK-871147516-2048.xml
#
# - Please add more!
#
# ### Command line interface
#
#	+ `--output-path` [path]: Set the path to the directory where output files will be written.  Default is "."
#	+ `--output-cache` [file]: Set the file name for the output cache.
#	+ `--n` [count] (int): Set the number of templates per output file (required).  It will be rounded to make all sub banks approximately the same size.
#	+ `--overlap` [count] (int): Overlap the templates in each file by this amount, must be even.
#	+ `--sort-by` [column]: Select the template sort order column.  Default is mchirp.
#	+ `--max-f-final` [max final freq] (float): Max f_final to populate table with; if f_final > max, use max.
#	+ `--instrument` [ifo]: Override the instrument, required
#	+ `--verbose`: Be verbose.
#	+ `--approximant` [mchirp_min:mchirp_max:string]: Must specify an approximant and validity range, can be given more than once
#	+ `--f-low` [frequency] (float): Lower frequency cutoff (required).
#	+ `--group-by-chi` [Number] (int): group templates into N chi bins with uniform number of templates - helps with SVD.
#
# ### Review status
#
# Compared original bank with the split banks.  Verified that they are the same, e.g., add sub bank files into test.xml.gz and run (except that lalapps_tmpltbank adds redundant templates):
#
#		ligolw_print -t sngl_inspiral -c mass1 -c mass2 ../H1-TMPLTBANK-871147516-2048.xml | sort -u | wc
#		ligolw_print -t sngl_inspiral -c mass1 -c mass2 test.xml.gz | sort -u | wc
#
# | Names 	                                        | Hash 					                   | Date       | Diff to Head of Master      |
# | ----------------------------------------------- | ---------------------------------------- | ---------- | --------------------------- |
# | Florent, Sathya, Duncan Me., Jolien, Kipp, Chad | 7536db9d496be9a014559f4e273e1e856047bf71 | 2014-04-28 | <a href="@gstlal_inspiral_cgit_diff/bin/gstlal_bank_splitter?id=HEAD&id2=7536db9d496be9a014559f4e273e1e856047bf71">gstlal_bank_splitter</a> |
#
# #### Action
#
# - Consider cleanup once additional bank programs are used and perhaps have additional metadata
#

class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
	# Content handler class passed as contenthandler= to
	# ligolw_utils.load_filename() when the input bank is read.
	pass
# NOTE(review): use_in() presumably enables the lsctables table types on this
# handler so that the sngl_inspiral table can be retrieved after loading --
# confirm against the glue.ligolw.lsctables documentation.
lsctables.use_in(LIGOLWContentHandler)


def group_templates(templates, n, overlap = 0):
	"""
	Generator that splits the sequence templates into consecutive groups
	of roughly n entries each.

	If n is at least len(templates) the sequence itself (not a copy) is
	yielded once.  Otherwise n is replaced by
	len(templates) / round(len(templates) / n) so that all groups come
	out approximately the same size, and each yielded slice is widened by
	overlap // 2 entries on both sides, clipped at the ends of the
	sequence.
	"""
	total = len(templates)
	if n >= total:
		yield templates
		return

	# Rescale the group size so a whole number of groups covers the sequence
	n = total / round(total / float(n))
	assert n >= 1

	half_overlap = overlap // 2
	index = 0
	while True:
		start = int(round(index * n)) - half_overlap
		end = int(round((index + 1) * n)) + half_overlap
		# start can be negative for the first group; a negative slice
		# index would count from the end, so clip it at zero
		yield templates[max(start, 0):end]
		if end >= total:
			break
		index += 1

def parse_command_line():
	"""
	Parse and validate the command line.

	Returns the tuple (options, filename), where options is the optparse
	values object and filename is the single input file name.  On return
	options.approximant has been converted from the list of
	"mchirp_min:mchirp_max:string" strings given on the command line into
	a list of (float(mchirp_min), float(mchirp_max), string) tuples.

	Raises ValueError if a required option is missing, if --n or
	--group-by-chi is not positive, if --overlap is odd, if the number of
	file names is not exactly one, or if an --approximant value cannot be
	parsed.
	"""
	parser = OptionParser()
	parser.add_option("--output-path", metavar = "path", default = ".", help = "Set the path to the directory where output files will be written.  Default is \".\".")
	parser.add_option("--output-cache", metavar = "file", help = "Set the file name for the output cache.")
	parser.add_option("--n", metavar = "count", type = "int", help = "Set the number of templates per output file (required). It will be rounded to make all sub banks approximately the same size.")
	parser.add_option("--overlap", default = 0, metavar = "count", type = "int", help = "overlap the templates in each file by this amount, must be even")
	parser.add_option("--sort-by", metavar = "column", default="mchirp", help = "Select the template sort column, default mchirp")
	parser.add_option("--max-f-final", metavar = "float", type="float", help = "Max f_final to populate table with; if f_final over mx, use max.")
	parser.add_option("--instrument", metavar = "ifo", type="string", help = "override the instrument, required")
	parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")
	parser.add_option("--approximant", type = "string", action = "append", help = "Must specify an approximant given as mchirp_min:mchirp_max:string")
	parser.add_option("--f-low", type = "float", metavar = "frequency", help = "Lower frequency cutoff. Required")
	parser.add_option("--group-by-chi", type = "int", metavar = "N", default = 1, help = "group templates into N groups of chi - helps with SVD. Default 1")
	options, filenames = parser.parse_args()

	required_options = ("n", "instrument", "sort_by", "output_cache", "approximant", "f_low")
	missing_options = [option for option in required_options if getattr(options, option) is None]
	if missing_options:
		# Call form of raise: valid on both Python 2 and Python 3
		raise ValueError("missing required option(s) %s" % ", ".join("--%s" % option.replace("_", "-") for option in missing_options))

	# Both values are used as divisors / group sizes by the splitting code,
	# so reject non-positive values here with a clear message
	if options.n < 1:
		raise ValueError("n must be a positive integer")
	if options.group_by_chi < 1:
		raise ValueError("group-by-chi must be a positive integer")

	if options.overlap % 2:
		raise ValueError("overlap must be even")

	if len(filenames) != 1:
		raise ValueError("Must give exactly one file name")

	# Convert each "mchirp_min:mchirp_max:string" into a typed tuple
	approximants = []
	for appx in options.approximant:
		try:
			mn, mx, appxstring = appx.split(":")
			approximants.append((float(mn), float(mx), appxstring))
		except ValueError:
			raise ValueError("invalid approximant \"%s\": must be given as mchirp_min:mchirp_max:string" % appx)
	options.approximant = approximants

	return options, filenames[0]

# Parse the command line at module level; options and filename are module
# globals read by the code that follows.
options, filename = parse_command_line()

def assign_approximant(mchirp, approximants = options.approximant):
	"""
	Return the approximant string of the first entry in approximants
	whose half-open interval [mchirp_min, mchirp_max) contains mchirp.

	approximants is a sequence of (mchirp_min, mchirp_max, string)
	entries.  Its default is the value options.approximant had when this
	function was defined, so later reassignment of options.approximant
	does not change the default.

	Raises ValueError if no interval contains mchirp.
	"""
	for lower, upper, name in approximants:
		if lower <= mchirp < upper:
			return name
	raise ValueError("Valid approximant not given for this chirp mass")
	

# Open the output cache now; entries are written to it further down the
# script, one per sub bank.
output_cache_file = open(options.output_cache, "w")
bank_count = 0

# List of (first row of group, rows of group) pairs, one per sub bank.  The
# first row supplies the key used to order the sub banks.
outputrows = []

xmldoc = ligolw_utils.load_filename(filename, verbose = options.verbose, contenthandler = LIGOLWContentHandler)
sngl_inspiral_table = lsctables.SnglInspiralTable.get_table(xmldoc)

# Number of templates repeated at each end of a chi bin.  Floor division
# keeps this an integer so that it can be used as a slice index.
half_overlap = options.overlap // 2

# Bin by Chi
sngl_inspiral_table.sort(key = lambda row: spawaveform.computechi(row.mass1, row.mass2, row.spin1z, row.spin2z))
for chirows in group_templates(sngl_inspiral_table, len(sngl_inspiral_table) // options.group_by_chi, overlap = 0):

	# Within a chi bin, order the templates by the requested column
	chirows.sort(key = lambda row: getattr(row, options.sort_by))

	for numrow, rows in enumerate(group_templates(chirows, options.n, options.overlap)):
		assert len(rows) >= options.n // 2, "There are too few templates in this chi interval.  Requested %d: have %d" % (options.n, len(rows))
		# Pad the first group with an extra overlap / 2 templates
		if numrow == 0:
			rows = rows[:half_overlap] + rows
		outputrows.append((rows[0], rows))
	# Pad the last group with an extra overlap / 2 templates.  This must be
	# skipped when there is no overlap: rows[-0:] is the whole list, so
	# padding with it would duplicate every template in the group.
	if half_overlap:
		outputrows[-1] = (rows[0], rows + rows[-half_overlap:])

# A sort of the groups of templates so that the sub banks are ordered.
def sort_func(group, column = options.sort_by):
	# group is a (first row, rows) pair; only the first row is used as key
	row, rows = group
	return getattr(row, column)

outputrows.sort(key=sort_func)

# Write one output document per sub bank and add a line for it to the output
# cache.  The try/finally makes sure the cache file is closed even if
# writing one of the sub banks fails.
try:
	for bank_count, (_, rows) in enumerate(outputrows):
		# just choose the first row to get mchirp
		# FIXME this could certainly be better
		approximant = assign_approximant(rows[0].mchirp)

		# Make an output document
		xmldoc = ligolw.Document()
		lw = xmldoc.appendChild(ligolw.LIGO_LW())
		sngl_inspiral_table = lsctables.New(lsctables.SnglInspiralTable)
		lw.appendChild(sngl_inspiral_table)
		# Override the approximant so that the process params stored
		# below carry the one chosen for this sub bank.
		# assign_approximant() is not affected: its default argument
		# was bound to the parsed list when it was defined.
		options.approximant = approximant
		# store the process params
		process = ligolw_process.register_to_xmldoc(xmldoc, program = "gstlal_bank_splitter", paramdict = options.__dict__, comment = "Split bank into smaller banks for SVD")

		for row in rows:
			# Chirptime uses SI
			m1_SI, m2_SI = MSUN_SI * row.mass1, MSUN_SI * row.mass2
			# Find the total spin magnitudes
			spin1, spin2 = (row.spin1x**2 + row.spin1y**2 + row.spin1z**2)**.5, (row.spin2x**2 + row.spin2y**2 + row.spin2z**2)**.5

			if approximant in templates.gstlal_IMR_approximants:
				# make sure to go a factor of 2 above the ringdown frequency for safety
				row.f_final = 2 * chirptime.ringf(m1_SI + m2_SI, chirptime.overestimate_j_from_chi(max(spin1, spin2)))
			else:
				# otherwise choose a suitable high frequency
				# NOTE not SI
				row.f_final = spawaveform.ffinal(row.mass1, row.mass2, 'bkl_isco')

			# Override the high frequency with the max if appropriate
			# (an unset or zero --max-f-final disables the cap)
			if options.max_f_final and (row.f_final > options.max_f_final):
				row.f_final = options.max_f_final
			row.process_id = process.process_id


			# Record the conservative template duration
			row.template_duration = chirptime.imr_time(options.f_low, m1_SI, m2_SI, spin1, spin2, f_max = row.f_final)

			# Make sure ifo and total mass is stored
			row.ifo = options.instrument
			row.mtotal = row.mass1 + row.mass2

		sngl_inspiral_table[:] = rows
		output = inspiral_pipe.T050017_filename(options.instrument, "GSTLAL_SPLIT_BANK_%04d" % bank_count, 0, 0, ".xml.gz", path = options.output_path)
		output_cache_file.write("%s\n" % gluelal.CacheEntry.from_T050017("file://localhost%s" % os.path.abspath(output)))
		ligolw_utils.write_filename(xmldoc, output, gz = output.endswith('gz'), verbose = options.verbose)
finally:
	output_cache_file.close()
