#!/usr/bin/python

# Copyright (C) 2012 Duncan Macleod
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Plot segments from file.

All files should be either SegWizard format, or LIGO_LW XML containing
both Segment and SegmentDef tables.

If LIGO_LW XML files are given, the flag names and instruments will be
extracted from the SegmentDefTable entries, otherwise no flag information
will be plotted.

Segments can be displayed versus time (default) or as a histogram.
"""

import warnings
warnings.filterwarnings("ignore", "column name (.*) is not lower case",
                        UserWarning)
warnings.filterwarnings("ignore", "Module dap was already imported from None",
                        UserWarning)

import os
import re
import optparse
import matplotlib
if not os.getenv("DISPLAY", None):
    import matplotlib
    matplotlib.use("agg", warn=False)

from glue.segments import (segment as Segment, segmentlist as SegmentList,
                           segmentlistdict as SegmentListDict)
from glue.segmentsUtils import fromsegwizard
from glue.ligolw import (table as ligolw_table, utils as ligolw_utils)
from glue.ligolw.lsctables import (SegmentTable, SegmentDefTable)

from pylal import (git_version, plotutils, plotsegments)
plotutils.set_rcParams()

__author__ = "Duncan Macleod <duncan.macleod@ligo.org>"
__version__ = git_version.id
__date__ = git_version.date

# Matches file names ending in ".xml" or ".xml.gz".
XML_FILE_REGEX = re.compile(r"\.xml(?:\.gz)?\Z")

# Matches keys naming a single instrument: a bare "X1" or an "X1:..." prefix.
SINGLE_IFO_REGEX = re.compile(r"[A-Z]1(?:\Z|:)")


def is_ligolw_xml(path):
    """Return True if `path` ends in ".xml" or ".xml.gz", otherwise False."""
    return XML_FILE_REGEX.search(path) is not None


def format_flag(ifo, name, version):
    """Return the "IFO:NAME:VERSION" label for a segment definer.

    If `version` cannot be formatted as an integer (e.g. it is None), the
    shorter "IFO:NAME" label is returned instead.
    """
    try:
        return "%s:%s:%d" % (ifo, name, version)
    except TypeError:
        return "%s:%s" % (ifo, name)


def flag_site(flag):
    """Return the lower-case site letter for a single-instrument key.

    Keys of the form "X1" or "X1:..." give "x"; any other key (for example
    a multi-instrument key such as "H1L1:...", or "_") gives None.
    """
    if SINGLE_IFO_REGEX.match(flag):
        return flag[0].lower()
    return None


def flag_offset(flag, opts):
    """Return the command-line offset that applies to `flag`, or None.

    The offset is read from the "<site>_offset" attribute of `opts`. None is
    returned when `flag` does not name a single instrument, or when no
    offset option exists for its site.
    """
    site = flag_site(flag)
    if site is None:
        return None
    return getattr(opts, "%s_offset" % site, None)


def plot_xlim(gps_start_time, gps_end_time):
    """Return [start, end] if either limit was given, otherwise None.

    The limits are compared against None, so an explicit value of 0 counts
    as given.
    """
    if gps_start_time is None and gps_end_time is None:
        return None
    return [gps_start_time, gps_end_time]


def add_segment(segments, key, seg):
    """Append `seg` to `segments[key]`, creating the list on first use."""
    try:
        segments[key].append(seg)
    except KeyError:
        segments[key] = SegmentList([seg])


def read_ligolw_segments(fp, segments, cumulative=False, by_instrument=False):
    """Read the segments in the LIGO_LW XML file `fp` into `segments`.

    If `cumulative` is True, every segment is stored under the "_" key and
    the SegmentDefTable is not read. Otherwise each segment is stored under
    its "IFO:NAME:VERSION" flag, or under just the instrument part of that
    flag if `by_instrument` is True.
    """
    xmldoc = ligolw_utils.load_filename(fp, gz=fp.endswith(".gz"))
    seg_table = ligolw_table.get_table(xmldoc, SegmentTable.tableName)

    # add zero-filled nanosecond columns to tables that do not have them
    # NOTE(review): presumably row.get() reads these columns -- confirm
    try:
        seg_table.getColumnByName("start_time_ns")
    except KeyError:
        seg_table.appendColumn("start_time_ns")
        seg_table.appendColumn("end_time_ns")
        for row in seg_table:
            row.start_time_ns = row.end_time_ns = 0

    if cumulative:
        for row in seg_table:
            add_segment(segments, "_", row.get())
        return

    # map each segment_def_id onto its flag label
    seg_def_table = ligolw_table.get_table(xmldoc, SegmentDefTable.tableName)
    flags = dict()
    for row in seg_def_table:
        ifo = "".join(sorted(row.get_ifos()))
        flags[int(row.segment_def_id)] = format_flag(ifo, row.name,
                                                     row.version)

    for row in seg_table:
        flag = flags[int(row.segment_def_id)]
        if by_instrument:
            flag = flag.split(":")[0]
        add_segment(segments, flag, row.get())


def read_segwizard_segments(fp, segments):
    """Read the SegWizard-format file `fp` into the "_" key of `segments`."""
    # imported locally because nothing else in this script needs the name
    # NOTE(review): assumes lsctables exposes LIGOTimeGPS at module level in
    # the glue release in use -- confirm
    from glue.ligolw.lsctables import LIGOTimeGPS
    with open(fp, "r") as f:
        seglist = fromsegwizard(f, coltype=LIGOTimeGPS)
    try:
        segments["_"] += seglist
    except KeyError:
        segments["_"] = seglist


if __name__ == "__main__":

    usage = "%prog [options] [file1 file2 ...]"
    epilog = "If you are having problems, please contact daswg@ligo.org."
    parser = optparse.OptionParser(usage=usage, description=__doc__,
                                   formatter=optparse.IndentedHelpFormatter(4),
                                   epilog=epilog,
                                   version=git_version.verbose_msg)
    parser.add_option("-o", "--output-file", action="store",
                      default=False, help="output file for the plot")

    segopts = parser.add_option_group("Segment options",
                                      "Manipulate segments before plotting.")
    segopts.add_option("-G", "--g-offset", action="store", type="float",
                       default=0.0,
                       help="offset for all GEO600 segments, default: %default")
    segopts.add_option("-H", "--h-offset", action="store", type="float",
                       default=0.0,
                       help="offset for all LHO segments, default: %default")
    segopts.add_option("-L", "--l-offset", action="store", type="float",
                       default=0.0,
                       help="offset for all LLO segments, default: %default")
    segopts.add_option("-V", "--v-offset", action="store", type="float",
                       default=0.0,
                       help="offset for all Virgo segments, default: %default")

    plotopts = parser.add_option_group("Plot options")
    plotopts.add_option("-c", "--cumulative", action="store_true",
                        default=False,
                        help=("Plot the union of all segments "
                              "as a single list, default: %default"))
    plotopts.add_option("-i", "--cumulative-by-instrument", action="store_true",
                        default=False,
                        help=("Plot the union of all segments for each "
                              "instrument as a single list, default: %default"))
    plotopts.add_option("-b", "--tight-bbox", action="store_true",
                        default=False, help=("use tight bounding box for "
                                             "plot, default: %default"))
    plotopts.add_option("-T", "--title", action="store", type="string",
                        default="", help=("title for plot, default: None"))
    plotopts.add_option("-S", "--subtitle", action="store", type="string",
                        default="", help=("subtitle for plot, default: None"))

    tplotopts = parser.add_option_group("Segments versus time options",
                                        "Customise the default plot")
    tplotopts.add_option("-s", "--gps-start-time", action="store",
                         type="float", metavar="GPS", default=None,
                         help="GPS start time for plot")
    tplotopts.add_option("-e", "--gps-end-time", action="store",
                         type="float", metavar="GPS", default=None,
                         help="GPS end time for plot")
    tplotopts.add_option("-l", "--labels-inset", action="store_true",
                         default=False, help=("set segment labels inside "
                                              "the axes, default: %default"))

    # FIXME: implement segment histogram; the group has no options yet and
    # only documents that fact in the help output
    parser.add_option_group(
        "Segment histogram options",
        "Histograms have not been implemented yet, apologies")

    # parse args
    opts, segfiles = parser.parse_args()

    # read segments: XML files are keyed by flag, everything else is read
    # as SegWizard format under the "_" key
    segments = SegmentListDict()
    for fp in segfiles:
        if is_ligolw_xml(fp):
            read_ligolw_segments(
                fp, segments, cumulative=opts.cumulative,
                by_instrument=opts.cumulative_by_instrument)
        else:
            read_segwizard_segments(fp, segments)
    segments.coalesce()

    # apply offsets to single-instrument keys whose site has an offset option
    for key in list(segments.keys()):
        offset = flag_offset(key, opts)
        if offset is not None:
            segments.offsets.update({key: offset})

    # plot segments versus time, with keys in reverse alphabetical order
    bbox = "tight" if opts.tight_bbox else None
    xlim = plot_xlim(opts.gps_start_time, opts.gps_end_time)
    plotsegments.plotsegmentlistdict(segments, opts.output_file,
                                     keys=sorted(segments.keys(),
                                                 reverse=True),
                                     xlim=xlim, title=opts.title,
                                     subtitle=opts.subtitle,
                                     bbox_inches=bbox,
                                     insetlabels=opts.labels_inset)
