#!/usr/bin/env python

# Copyright (C) 2012 Duncan Macleod
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

## \addtogroup laldetchar_py_triggers_utils
"""Utilities for the triggers package.

This package includes ways of determining the correct LIGO_LW table
for a given trigger generator, and a few helper functions.
"""
#
# ### Synopsis ###
#
#~~~
#from laldetchar.triggers import utils
#~~~
#\author Duncan Macleod (<duncan.macleod@ligo.org>)

import numpy
import re

from glue import iterutils
from glue.ligolw import (lsctables, table as ligolw_table,
                         utils as ligolw_utils)

from laldetchar import git_version
__author__ = "Duncan M. Macleod <duncan.macleod@ligo.org>"
__version__ = git_version.id
__date__ = git_version.date

# Regular expressions used to recognise trigger-generator (ETG) names.
# Raw strings are used so that the regex escapes (\A, \Z) are handed to
# `re` verbatim instead of relying on Python leaving unknown string
# escapes untouched (deprecated, and a SyntaxWarning in newer Pythons).
HACR_REGEX = re.compile(r"\Ahacr\Z", re.I)
KW_REGEX = re.compile(r"\A(kw|kleinewelle)\Z", re.I)
OMEGA_REGEX = re.compile(r"\Aomega", re.I)
OMICRON_REGEX = re.compile(r"\Aomicron\Z", re.I)
EXCESSPOWER_REGEX = re.compile(r'excesspower\Z', re.I)
CWB_REGEX = re.compile(r"(cwb|waveburst)", re.I)
IHOPE_REGEX = re.compile(r"ihope", re.I)
COH_PTF_REGEX = re.compile(r"coh_PTF", re.I)

# union of every ETG pattern that writes single-detector burst triggers
# (note: this includes the cWB pattern, so which_table must test
# MULTI_BURST_REGEX first for cWB to map to multi_burst)
SNGL_BURST_REGEX = re.compile("(%s)" % "|".join([HACR_REGEX.pattern,
                                                 KW_REGEX.pattern,
                                                 OMEGA_REGEX.pattern,
                                                 OMICRON_REGEX.pattern,
                                                 EXCESSPOWER_REGEX.pattern,
                                                 CWB_REGEX.pattern]))
# per-table aliases consumed by which_table
MULTI_BURST_REGEX = CWB_REGEX
SNGL_INSPIRAL_REGEX = IHOPE_REGEX
MULTI_INSPIRAL_REGEX = COH_PTF_REGEX
SNGL_RING_REGEX = re.compile(r'ringdown', re.I)


def _which_etg(etg):
    """Find the correct trigger generator name based on the input

    The input is converted to a lower-case string and then, in order:
    returned unchanged if it is the name of a known `LIGO_LW` table, or
    mapped to 'kw', 'omega', 'omicron' or 'cwb' if it matches the
    corresponding regular expression.

    ### Example: ###

    @code
    >>> _which_etg("KleineWelle")
    'kw'
    >>> _which_etg("Omicron")
    'omicron'
    @endcode

    @param etg
        the name of the trigger generator in question

    @returns the ETG name recognised by the triggers.from_file
    function

    @exception ValueError
        if no ETG name can be determined for the input
    """
    etg = str(etg).lower()
    if etg in lsctables.TableByName.keys():
        return etg
    elif KW_REGEX.search(etg):
        return 'kw'
    elif OMEGA_REGEX.search(etg):
        return 'omega'
    elif OMICRON_REGEX.search(etg):
        return 'omicron'
    elif CWB_REGEX.search(etg):
        return 'cwb'
    # the message previously contained an unformatted '%s' placeholder
    raise ValueError(("No ETG name understood for input \'%s\'. "
                      "Either it writes XML natively, or there just "
                      "isn't an I/O wrapper for it") % etg)

# open doxygen
## \addtogroup laldetchar_py_triggers_utils
#@{


def which_table(etg):
    """Find the correct table to use for a given trigger generator

    ### Example: ###

    @code
    >>> which_table("ExcessPower")
    'sngl_burst'
    @endcode

    The input is converted to a lower-case string. If that already names
    a known `LIGO_LW` table it is returned as-is; otherwise the ETG
    regular expressions are tried in a fixed order and the first match
    selects the table.

    @param etg
        the name of the trigger generator in question

    @returns the name of the `LIGO_LW` table appropriate for triggers
             generated by the given etg

    @exception ValueError
        if no table is mapped for the given name
    """
    name = str(etg).lower()
    if name in lsctables.TableByName.keys():
        return name
    # (pattern, lsctables class name) pairs, tested in order. multi_burst
    # must come before sngl_burst because SNGL_BURST_REGEX also contains
    # the cWB pattern. Class names are resolved only on a match.
    candidates = (
        (MULTI_BURST_REGEX, "MultiBurstTable"),
        (SNGL_BURST_REGEX, "SnglBurstTable"),
        (MULTI_INSPIRAL_REGEX, "MultiInspiralTable"),
        (SNGL_INSPIRAL_REGEX, "SnglInspiralTable"),
        (SNGL_RING_REGEX, "SnglRingdownTable"),
    )
    for pattern, class_name in candidates:
        if pattern.search(name):
            table_class = getattr(lsctables, class_name)
            return ligolw_table.StripTableName(table_class.tableName)
    raise ValueError("No LIGO_LW table mapped for ETG=\'%s\'" % name)


def new_ligolw_table(etg, columns=None):
    """Create an empty LIGO_LW table suited to the given trigger generator

    The table type is chosen with which_table and instantiated through
    `lsctables.New`.

    @param etg
        the name of the trigger generator in question
    @param columns
        a list of valid `LIGO_LW` column names for the new table
        (defaults to all)

    @returns a new `LIGO_LW` table (`glue.ligolw.table.Table`) for the
             given trigger generator
    """
    table_class = lsctables.TableByName[which_table(etg)]
    return lsctables.New(table_class, columns=columns)


def time_func(table_name, ifo=None):
    """Find the function that will return the 'time' for a row in
    a table.

    The name is first resolved with which_table, so trigger-generator
    names are accepted as well as `LIGO_LW` table names. A row-level
    `get_time` method is used if the row type has one; otherwise:
    - `sngl_burst` -> `get_peak`
    - `sngl_inspiral`, `multi_inspiral` -> `get_end`
    - `sim_inspiral` -> `get_end`; if `ifo` is given, a function that
      calls `row.get_end(site)` with the first character of `ifo` as
      the site
    - `sngl_ringdown` -> `get_start`

    @param table_name
        the name of the relevant `LIGO_LW` table (e.g. `sngl_burst`)
    @param ifo
        an interferometer prefix if you want single-detector times
        (only consulted for `sim_inspiral`)

    @returns a function that returns the 'time' value for
    a given `LIGO_LW` table row input

    @exception ValueError
        if no time method is known for the resolved table
    """
    ligolw_name = which_table(table_name)
    RowType = lsctables.TableByName[ligolw_name].RowType
    if hasattr(RowType, "get_time"):
        return RowType.get_time
    elif ligolw_name == "sngl_burst":
        return RowType.get_peak
    elif ligolw_name in ("sngl_inspiral", "multi_inspiral"):
        return RowType.get_end
    elif ligolw_name == "sim_inspiral":
        # previously inverted: the ifo was ignored when given, and a
        # lambda with site=None was built when it was not
        if ifo:
            # the site is the first character of the IFO prefix
            site = ifo[0]
            return lambda row: row.get_end(site)
        return RowType.get_end
    elif ligolw_name == "sngl_ringdown":
        return RowType.get_start
    raise ValueError("No known time method for ligolw_name=\'%s\'"
                     % ligolw_name)


def time_column(table, ifo=None):
    """Extract the 'time' column from the given table.

    This function uses time_func to determine the correct column to
    use as a proxy for 'time' and returns that column.
    The following mappings are used:
    - `sngl_inspiral` -> 'end' time
    - `sngl_burst` -> 'peak' time
    - `sngl_ringdown` -> 'start' time

    If the table itself has a `get_time` method, that is used directly.
    Otherwise, if the table has a table-level method with the same name
    as the row-level time function (e.g. `get_peak`), that is called.
    Failing that, the row-level function is applied to each row in turn.

    @param table
        any `LIGO_LW` table
    @param ifo
        an interferometer prefix if you want single-detector times;
        forwarded to time_func

    @returns a numpy array object with a 'time' element for each row in
    the table
    """
    if hasattr(table, "get_time"):
        return numpy.asarray(table.get_time())
    table_name = ligolw_table.StripTableName(table.tableName)
    # ifo was previously accepted but never passed on
    func = time_func(table_name, ifo=ifo)
    func_name = func.__name__
    if hasattr(table, func_name):
        return numpy.asarray(getattr(table, func_name)())
    # apply the row-level *function* (not its name, as before) to every
    # row; build a list so numpy sees a sequence, not a lazy iterator
    return numpy.asarray([func(row) for row in table])


def from_ligolw(filepath, table_name, columns=None, start=None, end=None,
                **kwargs):
    """Load a LIGO_LW table from a file.

    @param filepath
        path to `LIGO_LW` XML file
    @param table_name
        name of the requested `LIGO_LW` table (resolved via which_table,
        so a trigger-generator name is also accepted)
    @param columns
        a list of valid `LIGO_LW` column names for the new table
        (defaults to all)
    @param start
        minimum GPS time for returned triggers (inclusive)
    @param end
        maximum GPS time for returned triggers (exclusive)
    @param kwargs
        currently ignored

    @returns the requested `LIGO_LW` table.
    """
    table_name = which_table(table_name)
    if columns:
        # restrict the columns read from disk by temporarily overriding
        # the class-level `loadcolumns` attribute of this table type
        TableType = lsctables.TableByName[table_name]
        _oldcols = TableType.loadcolumns
        TableType.loadcolumns = columns
    try:
        xmldoc = ligolw_utils.load_filename(filepath)
        out = ligolw_table.get_table(xmldoc, table_name)
        # compare against None so that a bound of 0 is honoured; a missing
        # bound is simply not checked (no reliance on segments infinities,
        # which were never imported here)
        if start is not None or end is not None:
            get_time = time_func(table_name)

            def keep(row):
                t = float(get_time(row))
                return ((start is None or start <= t) and
                        (end is None or t < end))

            iterutils.inplace_filter(keep, out)
    finally:
        # restore the class-level setting even if loading fails, so later
        # loads of this table type are not silently restricted
        if columns:
            TableType.loadcolumns = _oldcols
    return out


def from_ascii(filepath, etg, columns=None, start=None, end=None, **kwargs):
    """Load a LIGO_LW table from `ASCII` files

    The ETG name is normalised with _which_etg, then the module
    `laldetchar.triggers.<etg>` is imported. Loading is delegated to its
    `from_ascii` function.

    @param filepath
        path to `ASCII` file
    @param etg
        name of the parent trigger generator
    @param columns
        a list of valid `LIGO_LW` column names for the new table
        (defaults to all)
    @param start
        minimum GPS time for returned triggers
    @param end
        maximum GPS time for returned triggers
    @param kwargs
        passed through unchanged to the ETG-specific `from_ascii`

    @returns a LIGO_LW table containing the triggers

    @exception ImportError
        if the ETG module cannot be imported
    @exception AttributeError
        if the ETG module has no `from_ascii` function
    """
    etg = _which_etg(etg)
    modname = "laldetchar.triggers.%s" % etg.lower()
    try:
        _etg_mod = __import__(modname, fromlist=[""])
    except ImportError as e:
        # keep the underlying reason: the module may exist but fail to
        # import one of its own dependencies
        raise ImportError("No module found for etg='%s' (%s)" % (etg, e))
    try:
        _load_triggers = getattr(_etg_mod, "from_ascii")
    except AttributeError:
        raise AttributeError("No 'from_ascii' function for etg='%s'" % etg)
    return _load_triggers(filepath, columns=columns, start=start, end=end,
                          **kwargs)


def from_root(filepath, etg, columns=None, start=None, end=None, **kwargs):
    """Load a `LIGO_LW` table from `ROOT` files

    The ETG name is normalised with _which_etg, then the module
    `laldetchar.triggers.<etg>` is imported. Loading is delegated to its
    `from_root` function.

    @param filepath
        path to `ROOT` file
    @param etg
        name of the parent trigger generator
    @param columns
        a list of valid `LIGO_LW` column names for the new table
        (defaults to all)
    @param start
        minimum GPS time for returned triggers
    @param end
        maximum GPS time for returned triggers
    @param kwargs
        passed through unchanged to the ETG-specific `from_root`

    @returns a `LIGO_LW` table containing the triggers

    @exception ImportError
        if the ETG module cannot be imported
    @exception AttributeError
        if the ETG module has no `from_root` function
    """
    etg = _which_etg(etg)
    modname = "laldetchar.triggers.%s" % etg.lower()
    try:
        _etg_mod = __import__(modname, fromlist=[""])
    except ImportError as e:
        # keep the underlying reason: the module may exist but fail to
        # import one of its own dependencies
        raise ImportError("No module found for etg='%s' (%s)" % (etg, e))
    try:
        _load_triggers = getattr(_etg_mod, "from_root")
    except AttributeError:
        raise AttributeError("No 'from_root' function for etg='%s'" % etg)
    return _load_triggers(filepath, columns=columns, start=start, end=end,
                          **kwargs)

# close doxygen
##@}
