#!/usr/bin/python

#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#

import os, sys, copy, time
import numpy
from optparse import OptionParser

from glue import lal
from glue import git_version
from glue.ligolw import ligolw
from glue.ligolw import table
from glue.ligolw import utils
from glue.ligolw import lsctables
from glue.ligolw.utils import process
from glue.ligolw.utils.print_tables import print_tables
from glue import segments

from pylal.xlal.date import XLALGPSToUTC
try:
    from pylal.xlal.datatypes.ligotimegps import LIGOTimeGPS
except ImportError:
    # s6 code
    from pylal.xlal.date import LIGOTimeGPS

# Program name; passed to process.register_to_xmldoc when the output
# document is set up.
__prog__ = "lalapps_cbc_print_rs"
__author__ = "Collin Capano <cdcapano@physics.syr.edu>"

# One-line summary handed to OptionParser as the program description.
description = "Prints information about rate statistics."

# =============================================================================
#
#                                   Set Options
#
# =============================================================================


def parse_command_line():
    """
    Parse the command line and check the options for consistency.

    Returns a tuple (options, filenames): the optparse values object and the
    list of positional arguments (the rate files to read).

    Raises ValueError if --output-format is missing, if --glitchiest-files or
    --quietest-files is given without --selection-criteria, or if
    --merge-with-loudest-events is given without both --le-xml and --full-xml.
    """
    parser = OptionParser(
        version = git_version.verbose_msg,
        usage   = "%prog --output-format [options] rate_file.xml1 rate_file.xml2 ...",
        description = description )

    parser.add_option("-f", "--output-format", action = "store", type = "string", default = None,
        metavar = "wiki, html, OR xml",
        help =
            "Required. Output format to print tables as. Options are 'wiki', 'html', or 'xml'."
            )
    parser.add_option("-o", "--output", action = "store", type = "string", default = None,
        help =
            "File to save output to. If none specified, will print to stdout."
            )
    parser.add_option( "-r", "--rate-stats", action = "store_true", default = False,
        help =
            "Print the rate statistics table."
            )
    parser.add_option( "-g", "--glitchiest-files", action = "store_true", default = False,
        help =
            "Print glitchiest files table. Requires --selection-criteria."
            )
    parser.add_option( "-q", "--quietest-files", action = "store_true", default = False,
        help =
            "Print quietest files table. Requires --selection-criteria."
            )
    parser.add_option( "-s", "--selection-criteria", action = "store", default = None,
        help =
            "What stat to select files by for glitchiest/quietest table. Options are " +
            "trigger_rate, trigger_rate_per_tmplt, average_snr, max_snr, median_snr, or snr_standard_deviation."
            )
    parser.add_option("-n", "--num-files", action = "store", type = "int", default = 5,
        help =
            "The number of glitchiest/quietest files per file-type to print info about. Default is 5."
            )
    parser.add_option("-m", "--merge-with-loudest-events", action = "store_true", default = False,
        help =
            "Requires --le-xml and --full-xml. Will select segments from the file_rank table " +
            "that intersects with the coincident events listed in the loudest_events table " +
            "in --le-xml. Information about the segments will be merged with information from " +
            "the loudest events table."
            )
    parser.add_option("-l", "--le-xml", action = "store", type = "string", default = None,
        help =
            "An xml file containing a loudest events table with which to select segments."
            )
    parser.add_option("-x", "--full-xml", action = "store", type = "string", default = None,
        help =
            "An xml file containing all of the lsctables to get needed additional information " +
            "about the loudest events used for selection."
            )

    (options, filenames) = parser.parse_args()

    # Consistency checks. The call form of raise is used because it is valid
    # on both Python 2 and Python 3, unlike "raise ValueError, msg".
    if not options.output_format:
        raise ValueError("--output-format is a required option")
    if (options.glitchiest_files or options.quietest_files) and not options.selection_criteria:
        raise ValueError("--glitchiest-files and --quietest-files require --selection-criteria")
    if options.merge_with_loudest_events and not (options.le_xml and options.full_xml):
        raise ValueError("--merge-with-loudest-events requires --le-xml and --full-xml")

    return options, filenames


# =============================================================================
#
#                       Class Definitions
#
# =============================================================================

class SelectedSegmentsTable(table.Table):
    """
    Table named "selected_segments" that pairs a coincident event and one of
    its single-ifo triggers with the statistics of a file segment.
    """
    tableName = "selected_segments:table"
    validcolumns = {
        # the coincident event
        "coinc_event_rank": "int_4u",
        "coinc_event_ifos": "lstring",
        "coinc_event_combined_far": "real_8",
        "coinc_event_end_time": "int_4s",
        "coinc_event_mchirp": "real_8",
        "coinc_event_mtotal": "real_8",
        # the single-ifo trigger belonging to the coincident event
        "sngl_event_ifo": "lstring",
        "sngl_event_unslid_time_utc": "lstring",
        "sngl_event_unslid_time": "int_4s",
        "sngl_event_unslid_time_ns": "int_4s",
        # the segment and its statistics
        # NOTE(review): the snr-based *_rank columns are declared real_8 while
        # the rate-based *_rank columns are int_4u -- confirm this is intended.
        "file_type": "lstring",
        "segment_start_time_utc": "lstring",
        "segment_duration": "int_4u",
        "num_triggers": "int_4u",
        "trigger_rate": "real_8",
        "trigger_rate_rank": "int_4u",
        "num_tmplts": "int_4u",
        "trigger_rate_per_tmplt": "real_8",
        "trigger_rate_per_tmplt_rank": "int_4u",
        "average_snr": "real_8",
        "average_snr_rank": "real_8",
        "max_snr": "real_8",
        "max_snr_rank": "real_8",
        "median_snr": "real_8",
        "median_snr_rank": "real_8",
        "snr_standard_deviation": "real_8",
        "snr_standard_deviation_rank": "real_8",
        # links
        "daily_ihope_page": "lstring",
        "elog_page": "lstring",
    }
        
class SelectedSegments(object):
    """
    Row type for SelectedSegmentsTable. Instances only accept the attributes
    named by the keys of SelectedSegmentsTable.validcolumns.
    """
    # restrict instance attributes to the table's column names
    __slots__ = SelectedSegmentsTable.validcolumns.keys()

    def get_pyvalue(self):
        """
        Return self.value converted through ligolwtypes.ToPyType, using
        self.type (or "lstring" if self.type is false) to pick the converter;
        return None if self.value is None.

        NOTE(review): "value" and "type" are not among the validcolumns keys
        that make up __slots__, and no "ligolwtypes" import is visible in this
        file, so this method does not look usable as written. It is not called
        anywhere in this script -- confirm whether it can be removed.
        """
        if self.value is None:
            return None
        return ligolwtypes.ToPyType[self.type or "lstring"](self.value)

# rows of SelectedSegmentsTable are SelectedSegments instances
SelectedSegmentsTable.RowType = SelectedSegments


# =============================================================================
#
#                                     Main
#
# =============================================================================
    
# Parse the command line, pool the rate_statistics and file_statistics tables
# of every input file, and prepare the output document.
opts, filenames = parse_command_line()


# Also treat the special name 'stdin' as stdin: such entries (and empty
# names) are replaced by None before being handed to utils.load_filename.
filenames = [None if (not fn or fn == 'stdin') else fn for fn in filenames]

# Without at least one input, rate_table and seg_table would never be bound
# and every later use of them would fail with a NameError.
if not filenames:
    raise ValueError("no rate files given; specify at least one rate file (or 'stdin')")

# setup the input rate doc
rate_doc = ligolw.Document()
rate_doc.appendChild(ligolw.LIGO_LW())
for n, filename in enumerate(filenames):
    thisdoc = utils.load_filename( filename, gz = (filename or "stdin").endswith(".gz") )
    this_rate_table = table.get_table(thisdoc, "rate_statistics:table")
    this_seg_table = table.get_table(thisdoc, "file_statistics:table")
    if n == 0:
        # the first file's tables become the pooled tables of rate_doc
        rate_doc.childNodes[0].appendChild(this_rate_table)
        rate_doc.childNodes[0].appendChild(this_seg_table)
        rate_table = table.get_table(rate_doc, "rate_statistics:table")
        seg_table = table.get_table(rate_doc, "file_statistics:table")
    else:
        # rows of every further file are appended to the pooled tables
        rate_table.extend(this_rate_table)
        seg_table.extend(this_seg_table)

# setup the output doc
out_doc = ligolw.Document()
# setup the LIGOLW tag
out_doc.appendChild(ligolw.LIGO_LW())
# add this program's metadata; this is done before opts.output is replaced
# below, presumably so the recorded options hold the command-line values
proc_id = process.register_to_xmldoc(out_doc, __prog__, opts.__dict__, version = git_version.id)

# set the output to sys.stdout if None specified
if opts.output is None:
    opts.output = sys.stdout

if opts.rate_stats:
    out_doc.childNodes[0].appendChild(rate_table)

if opts.glitchiest_files:
    # name of the file_statistics column holding the rank for the chosen stat
    glitch_rank_col = opts.selection_criteria + '_rank'

    # table type with the same columns as the pooled file_statistics table
    class GlitchiestTable(table.Table):
        tableName = "glitchiest_files_by_" + opts.selection_criteria + ":table"
        validcolumns = dict(zip(seg_table.columnnames, seg_table.columntypes))

    glitch_table = lsctables.New(GlitchiestTable)
    out_doc.childNodes[0].appendChild(glitch_table)

    # ifos in sorted order; file types in order of first appearance
    ifos = sorted(set(row.ifo for row in seg_table))
    filetypes = []
    for row in seg_table:
        if row.file_type not in filetypes:
            filetypes.append(row.file_type)

    def glitch_sort_key(row):
        # order by ifo, then file type, then rank
        return (ifos.index(row.ifo), filetypes.index(row.file_type), getattr(row, glitch_rank_col))

    # keep the rows whose rank is within the first opts.num_files
    glitch_rows = [row for row in seg_table if getattr(row, glitch_rank_col) <= opts.num_files]
    glitch_rows.sort(key = glitch_sort_key)
    glitch_table += glitch_rows

if opts.quietest_files:
    # name of the file_statistics column holding the rank for the chosen stat
    quiet_rank_col = opts.selection_criteria + '_rank'

    # table type with the same columns as the pooled file_statistics table
    class QuietestTable(table.Table):
        tableName = "quietest_files_by_" + opts.selection_criteria + ":table"
        validcolumns = dict(zip(seg_table.columnnames, seg_table.columntypes))

    quiet_table = lsctables.New(QuietestTable)
    out_doc.childNodes[0].appendChild(quiet_table)

    # ifos in sorted order; file types in order of first appearance
    ifos = sorted(set(row.ifo for row in seg_table))
    filetypes = []
    for row in seg_table:
        if row.file_type not in filetypes:
            filetypes.append(row.file_type)

    def quiet_sort_key(row):
        # order by ifo, then file type, then distance of the rank from
        # row.num_files
        return (ifos.index(row.ifo), filetypes.index(row.file_type), row.num_files - getattr(row, quiet_rank_col))

    # keep the rows whose rank lies in the last opts.num_files ranks, counted
    # back from row.num_files
    quiet_rows = [row for row in seg_table if getattr(row, quiet_rank_col) > row.num_files - opts.num_files]
    quiet_rows.sort(key = quiet_sort_key)
    quiet_table += quiet_rows

if opts.merge_with_loudest_events:
    # For every row of the loudest_events table, find the file_statistics
    # segments that contain the end time of each of its single-ifo triggers
    # and write one selected_segments row per (file type, ifo) match.

    # create the table
    selected_segs_table = lsctables.New(SelectedSegmentsTable)
    # add to out_doc
    out_doc.childNodes[0].appendChild(selected_segs_table)

    # get the needed docs
    letable_doc = utils.load_filename( opts.le_xml, gz = opts.le_xml.endswith(".gz") )
    fullxml_doc = utils.load_filename( opts.full_xml, gz = opts.full_xml.endswith(".gz") )

    # get the tables needed to go from a loudest event to its single-ifo triggers
    letable = table.get_table(letable_doc, "loudest_events:table")
    sngl_insp_table = table.get_table(fullxml_doc, lsctables.SnglInspiralTable.tableName)
    coinc_map_table = table.get_table(fullxml_doc, lsctables.CoincMapTable.tableName)
    exp_summ_table = table.get_table(fullxml_doc, lsctables.ExperimentSummaryTable.tableName)
    exp_map_table = table.get_table(fullxml_doc, lsctables.ExperimentMapTable.tableName)
    offset_vectors = table.get_table(fullxml_doc, lsctables.TimeSlideTable.tableName).as_dict()

    # index the file_statistics rows by (file_type, ifo, segment)
    segdict = dict([ [(row.file_type, row.ifo, segments.segment(row.out_start_time, (row.out_start_time+row.segment_duration))), row] for row in seg_table ])

    # get lerank name: the first column of the loudest_events table whose
    # name contains "rank"
    rank_columns = [col.getAttribute("Name").split(":")[-1] for col in letable.getElementsByTagName(u'Column') if "rank" in col.getAttribute("Name")]
    if not rank_columns:
        raise ValueError("no rank column found in the loudest_events table of %s" % opts.le_xml)
    lerankname = rank_columns[0]

    # file_statistics columns copied unchanged, under the same name, into each
    # selected_segments row
    # NOTE(review): the print-out column lists in this script use "num_events"
    # for file_statistics rows -- confirm those rows really carry "num_triggers".
    copied_seg_columns = (
        "segment_duration",
        "num_triggers",
        "trigger_rate",
        "trigger_rate_rank",
        "num_tmplts",
        "trigger_rate_per_tmplt",
        "trigger_rate_per_tmplt_rank",
        "average_snr",
        "average_snr_rank",
        "max_snr",
        "max_snr_rank",
        "median_snr",
        "median_snr_rank",
        "snr_standard_deviation",
        "snr_standard_deviation_rank",
        "daily_ihope_page",
        "elog_page",
        )

    for lerow in letable:
        ceid = lerow.coinc_event_id
        # get the time_slide_id
        esids = [row.experiment_summ_id for row in exp_map_table if row.coinc_event_id == ceid]
        tsids = set([row.time_slide_id for row in exp_summ_table if row.experiment_summ_id in esids])
        if not tsids:
            # otherwise tsids.pop() below would fail with a bare KeyError
            raise ValueError("no time slide id found for coinc event %s" % str(ceid))
        if len(tsids) > 1:
            raise ValueError("multiple time slide ids found for coinc event %s" % str(ceid))
        offset_vect = offset_vectors[tsids.pop()]
        # get the end times of each ifo
        event_ids = [row.event_id for row in coinc_map_table if row.coinc_event_id == ceid and row.table_name == "sngl_inspiral"]
        event_info = dict([ [row.ifo, row] for row in sngl_insp_table if row.event_id in event_ids ])
        selected_segs = {}
        non_slid_times = {}
        for ifo, sngl_event in event_info.items():
            # end time of the trigger plus the offset of its ifo
            non_slid_times[ifo] = end_time = sngl_event.end_time + offset_vect[ifo]
            # segments of this ifo that contain that time, keyed by file type
            selected_segs[ifo] = dict([ [filetype, row] for(filetype, seg_ifo, seg), row in segdict.items() if seg_ifo == ifo and end_time in seg ])
        filetypes = set([filetype for ifo in selected_segs.keys() for filetype in selected_segs[ifo].keys()])
        for filetype in filetypes:
            ifos = sorted([ ifo for ifo in selected_segs.keys() if filetype in selected_segs[ifo] ])
            for ifo in ifos:
                seg = selected_segs[ifo][filetype]
                selected_seg = SelectedSegments()
                # information about the coincident event
                selected_seg.coinc_event_rank = getattr(lerow, lerankname)
                selected_seg.coinc_event_end_time = lerow.end_time
                selected_seg.coinc_event_ifos = lerow.ifos
                selected_seg.coinc_event_combined_far = lerow.combined_far
                selected_seg.coinc_event_mchirp = lerow.mchirp
                selected_seg.coinc_event_mtotal = lerow.mass
                # information about the single-ifo trigger
                selected_seg.sngl_event_ifo = ifo
                selected_seg.sngl_event_unslid_time = non_slid_times[ifo]
                selected_seg.sngl_event_unslid_time_ns = event_info[ifo].end_time_ns
                sngl_end_time_utc = XLALGPSToUTC(LIGOTimeGPS(non_slid_times[ifo], 0))
                selected_seg.sngl_event_unslid_time_utc = time.strftime("%a %d %b %Y %H:%M:%S", sngl_end_time_utc)
                # information about the segment
                selected_seg.file_type = filetype
                selected_seg.segment_start_time_utc = seg.out_start_time_utc
                for colname in copied_seg_columns:
                    setattr(selected_seg, colname, getattr(seg, colname))
                selected_segs_table.append(selected_seg)


# save the results
if opts.output_format == 'xml':
    # utils.write_filename takes a file name rather than a file object; when
    # no --output was given opts.output holds sys.stdout, so pass None, which
    # is write_filename's way of asking for stdout.
    xml_target = opts.output
    if xml_target is sys.stdout:
        xml_target = None
    utils.write_filename( out_doc, xml_target, xsl_file = "ligolw.xsl")

else:
    # names of the tables to print, and "table:column" names of the columns
    # to print from them
    tableList = []
    columnList = []
    if opts.rate_stats:
        thislist = [
            "file_type",
            "ifo",
            "average_rate",
            "average_rate_per_tmplt",
            "max_rate",
            "min_rate",
            "median_rate",
            "standard_deviation",
            "average_num_tmplts"]
        columnList.extend( ':'.join(["rate_statistics",colname]) for colname in thislist )
        tableList.append('rate_statistics')

    # the glitchiest and quietest tables print the same columns
    file_columns = [
        "ifo",
        "file_type",
        "file_name",
        "trigger_rate_rank",
        "trigger_rate_per_tmplt_rank",
        "average_snr_rank",
        "max_snr_rank",
        "median_snr_rank",
        "snr_standard_deviation_rank",
        "out_start_time_utc",
        "segment_duration",
        "num_events",
        "trigger_rate",
        "num_tmplts",
        "trigger_rate_per_tmplt",
        "average_snr",
        "max_snr",
        "median_snr",
        "snr_standard_deviation",
        "elog_page",
        "daily_ihope_page"]
    if opts.glitchiest_files:
        columnList.extend( ':'.join(['glitchiest_files_by_'+opts.selection_criteria, colname]) for colname in file_columns )
        tableList.append('glitchiest_files_by_'+opts.selection_criteria)
    if opts.quietest_files:
        columnList.extend( ':'.join(['quietest_files_by_'+opts.selection_criteria, colname]) for colname in file_columns )
        tableList.append('quietest_files_by_'+opts.selection_criteria)
    if opts.merge_with_loudest_events:
        # these must be columns of SelectedSegmentsTable, which names the
        # segment start 'segment_start_time_utc' and the trigger count
        # 'num_triggers'
        thislist = [
            'coinc_event_rank',
            'coinc_event_end_time',
            'coinc_event_ifos',
            'coinc_event_combined_far',
            'coinc_event_mchirp',
            'coinc_event_mtotal',
            'file_type',
            'sngl_event_ifo',
            'sngl_event_unslid_time_utc',
            'segment_start_time_utc',
            'segment_duration',
            'num_triggers',
            'trigger_rate',
            'trigger_rate_rank',
            'num_tmplts',
            'trigger_rate_per_tmplt',
            'trigger_rate_per_tmplt_rank',
            'average_snr',
            'average_snr_rank',
            'max_snr',
            'max_snr_rank',
            'median_snr',
            'median_snr_rank',
            'snr_standard_deviation',
            'snr_standard_deviation_rank',
            'elog_page',
            'daily_ihope_page']
        columnList.extend( ':'.join(['selected_segments',colname]) for colname in thislist )
        tableList.append('selected_segments')

    # The glitchiest/quietest column names can only be built when
    # --selection-criteria was given; it is None when, e.g., only --rate-stats
    # is requested, and concatenating None to a string raises a TypeError.
    criteria_elog_columns = []
    criteria_ifo_columns = []
    if opts.selection_criteria is not None:
        for prefix in ('glitchiest_files_by_', 'quietest_files_by_'):
            criteria_table = prefix + opts.selection_criteria
            criteria_elog_columns.append(criteria_table + ':elog_page')
            criteria_ifo_columns.append(criteria_table + ':ifo')

    row_span_columns = [
        'ifo',
        'file_type'] + criteria_elog_columns + [
        'daily_ihope_page',
        'coinc_event_rank',
        'coinc_event_end_time',
        'coinc_event_ifos',
        'coinc_event_combined_far',
        'coinc_event_mchirp',
        'coinc_event_mtotal']
    # NOTE(review): 'coinc_event_time' is not a column of any table built in
    # this script -- confirm whether another name was meant.
    rspan_break_columns = criteria_ifo_columns + [
        'coinc_event_end_time',
        'coinc_event_time',
        'coinc_event_combined_far',
        'coinc_event_mchirp',
        'coinc_event_mtotal' ]

    # open the output file if a name was given; sys.stdout is used as is and
    # is left open
    opened_here = opts.output is not sys.stdout
    if opened_here:
        opts.output = open(opts.output, 'w')
    try:
        print_tables(out_doc, opts.output, opts.output_format, tableList = tableList, columnList = columnList,
            round_floats = True, decimal_places = 2, format_links = True, title = None, print_table_names = True,
            row_span_columns = row_span_columns, rspan_break_columns = rspan_break_columns)
    finally:
        # close the file even if printing fails
        if opened_here:
            opts.output.close()

sys.exit(0)

