#!/usr/bin/env python
# Copyright (C) 2010  Nickolas Fotopoulos
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""
Given an XML file with a SimInspiralTable, align each binary so that the total
angular momentum has the recorded inclination instead of the orbital
angular momentum. This will alter inclination, spins, and effective distance.

Usage: %prog --output-file OUTFILE siminsp_fname
"""
from __future__ import division

__author__ = "Nickolas Fotopoulos <nickolas.fotopoulos@ligo.org>"

import math
import optparse
import sys

import numpy as np

from glue.ligolw import ligolw
from glue.ligolw import lsctables
from glue.ligolw import table
from glue.ligolw import utils
from glue.ligolw.utils import process as ligolw_process
from pylal import date
from pylal import git_version
from pylal import sphericalutils as su
from pylal import inject
from lal import PI as LAL_PI
from lal import TWOPI as LAL_TWOPI
from lal import MTSUN_SI as LAL_MTSUN_SI


#
# Utility functions
#

def eff_distance(detector, sim):
    """
    Compute the effective distance of an injection for a single detector.

    detector -- object whose ``response`` attribute is handed to
                inject.XLALComputeDetAMResponse
    sim      -- sim_inspiral row; its longitude, latitude, polarization,
                end_time_gmst, inclination and distance attributes are read

    The value returned is

        2 * distance / sqrt(F+^2 * A+^2 + Fx^2 * Ax^2)

    with A+ = -(1 + cos^2(inclination)) and Ax = -2 cos(inclination), where
    (F+, Fx) is the pair returned by XLALComputeDetAMResponse.

    Ref: Duncan's PhD thesis, eq. (4.3) on page 57, implemented in
         LALInspiralSiteTimeAndDist in SimInspiralUtils.c:594.
    """
    f_plus, f_cross = inject.XLALComputeDetAMResponse(
        detector.response,
        sim.longitude,
        sim.latitude,
        sim.polarization,
        sim.end_time_gmst)

    # Inclination-dependent amplitude factors of the two polarizations
    cos_iota = math.cos(sim.inclination)
    amp_plus = -(1 + cos_iota * cos_iota)
    amp_cross = -2 * cos_iota

    plus_power = f_plus * f_plus * amp_plus * amp_plus
    cross_power = f_cross * f_cross * amp_cross * amp_cross
    return 2 * sim.distance / math.sqrt(plus_power + cross_power)

def get_spins(sim):
    """
    Extract both component spin vectors from a sim_inspiral row.

    Returns the pair (s1, s2), each a length-3 numpy array assembled from the
    row's spin1x/spin1y/spin1z and spin2x/spin2y/spin2z attributes.
    """
    s1 = np.array([getattr(sim, "spin1" + axis) for axis in "xyz"])
    s2 = np.array([getattr(sim, "spin2" + axis) for axis in "xyz"])
    return s1, s2

def set_spins(sim, s1, s2):
    """
    Store two spin 3-vectors in a sim_inspiral row.

    Elements 0, 1, 2 of s1 are written to spin1x, spin1y, spin1z and those
    of s2 to spin2x, spin2y, spin2z. The row is modified in place; nothing
    is returned.
    """
    for prefix, vec in (("spin1", s1), ("spin2", s2)):
        for idx, axis in enumerate("xyz"):
            setattr(sim, prefix + axis, vec[idx])

def alignTotalSpin(iota, s1, s2, m1, m2, f_lower):
    """
    Re-orient a binary so that its total angular momentum J, rather than its
    orbital angular momentum L, is inclined by iota.

    iota    -- inclination currently assigned to L (radians; it is fed
               directly to sin/cos)
    s1, s2  -- dimensionless spin 3-vectors as numpy arrays
    m1, m2  -- component masses (presumably solar masses, since they are
               multiplied by LAL_MTSUN_SI -- confirm against the caller)
    f_lower -- frequency entering the orbital angular momentum magnitude
               (presumably Hz -- confirm against the caller)

    Returns (newIota, news1, news2): the inclination of L after the
    rotation and the two rotated spin vectors.

    Two assertions check the result: that J ends up at inclination iota
    (to within 1e-4 rad) and that |J|^2 is unchanged by the rotations (to
    within 1e-4 of the squared orbital angular momentum magnitude).

    Adapted from Michal Was's implementation in the X Pipeline:
    https://geco.phys.columbia.edu/xpipeline/browser/trunk/utilities/alignTotalSpin.m
    """
    def xz_unit(angle):
        # Unit vector in the XZ plane making the given angle with +Z.
        return np.array([np.sin(angle), 0, np.cos(angle)])

    m1sq = m1 * m1
    m2sq = m2 * m2

    # Step 1: turn (S1, S2) rigidly about L until the total spin lies in
    # the XZ plane.
    # NB: inspinj arranges L to initially be in the XZ plane
    orb_hat = xz_unit(iota)
    in_plane = np.cross(orb_hat, [0, 1, 0])
    out_of_plane = np.cross(orb_hat, in_plane)
    spin_sum = m1sq * s1 + m2sq * s2  # give units
    azimuth = np.arctan2(np.dot(spin_sum, out_of_plane),
                         np.dot(spin_sum, in_plane))
    planar_s1 = su.rotate_about_axis(s1, orb_hat, -azimuth)
    planar_s2 = su.rotate_about_axis(s2, orb_hat, -azimuth)

    # Step 2: form the total angular momentum, which now lies in the XZ
    # plane, and read off its inclination.
    orb_scale = m1 * m2 * (LAL_PI * LAL_MTSUN_SI * (m1 + m2) * f_lower)**(-1.0/3.0)
    total_j = orb_scale * orb_hat + m1sq * planar_s1 + m2sq * planar_s2
    j_incl = np.arctan2(total_j[0], total_j[2])

    # Step 3: turn (L, S1, S2) rigidly about the Y axis by (iota - j_incl)
    # so that J takes over the inclination iota.
    final_iota = iota + (iota - j_incl)
    final_s1 = su.rotate_about_axis(planar_s1, [0, 1, 0], iota - j_incl)
    final_s2 = su.rotate_about_axis(planar_s2, [0, 1, 0], iota - j_incl)

    # Sanity check: J has the desired inclination
    j_after = orb_scale * xz_unit(final_iota) \
        + m1sq * final_s1 + m2sq * final_s2
    j_after_incl = np.arctan2(
        np.sqrt(j_after[0] * j_after[0] + j_after[1] * j_after[1]),
        j_after[2])
    assert abs(iota - j_after_incl) < 1e-4

    # Sanity check: J has its original magnitude
    j_before = orb_scale * xz_unit(iota) \
        + m1sq * s1 + m2sq * s2
    assert abs(np.dot(j_before, j_before) - np.dot(j_after, j_after)) \
        < orb_scale*orb_scale * 1e-4

    return final_iota, final_s1, final_s2


#
# Parse commandline
#

# The module docstring doubles as the usage message.
parser = optparse.OptionParser(
    usage=__doc__,
    version=git_version.verbose_msg)
parser.add_option(
    "--output-file",
    metavar="OUTFILE",
    help="name of output XML file")
opts, args = parser.parse_args()

# Both --output-file and exactly one positional input file are mandatory;
# otherwise show the usage text and exit with status 2.
if not (opts.output_file is not None and len(args) == 1):
    parser.print_usage()
    sys.exit(2)

(siminsp_fname,) = args

#
# Read inputs
#

# Load the injection document; the gz flag is set exactly when the file
# name ends in ".gz".
siminsp_doc = utils.load_filename(
    siminsp_fname,
    gz=siminsp_fname.endswith(".gz"),
    contenthandler=lsctables.use_in(ligolw.LIGOLWContentHandler))

# Register this program, its command-line options and its version in the
# document's process information.
process = ligolw_process.register_to_xmldoc(
    siminsp_doc,
    "ligolw_cbc_align_total_spin",
    vars(opts),
    version=git_version.tag or git_version.id,
    cvs_repository="lalsuite",
    cvs_entry_time=git_version.date)


#
# Do the work
#
# Pair each one-letter site code (the suffix of the eff_dist_* attributes)
# with its cached detector object.
site_location_list = list(zip(
    ("h", "l", "g", "t", "v"),
    (inject.cached_detector["LHO_4k"],
     inject.cached_detector["LLO_4k"],
     inject.cached_detector["GEO_600"],
     inject.cached_detector["TAMA_300"],
     inject.cached_detector["VIRGO"])))

for sim in table.get_table(siminsp_doc, lsctables.SimInspiralTable.tableName):
    # Rotate the binary: the row receives a new inclination and new spins.
    s1, s2 = get_spins(sim)
    sim.inclination, news1, news2 = alignTotalSpin(
        sim.inclination, s1, s2, sim.mass1, sim.mass2, sim.f_lower)
    set_spins(sim, news1, news2)

    # Recompute the effective distance at every site, since it depends on
    # the inclination that was just changed.
    for site, detector in site_location_list:
        setattr(sim, "eff_dist_%s" % site, eff_distance(detector, sim))


#
# Write output
#

# Stamp the process entry with its end time before serializing.
ligolw_process.set_process_end_time(process)

# Write the modified document. The gz flag follows the output file name, the
# same rule used when loading the input; without it an OUTFILE ending in
# ".gz" would be written uncompressed despite its name.
utils.write_filename(siminsp_doc, opts.output_file,
    gz=opts.output_file.endswith(".gz"))
