#!/usr/bin/python


# ============================================================================
#
#                               Preamble
#
# ============================================================================

from optparse import OptionParser
try:
    import sqlite3
except ImportError:
    # pre 2.5.x
    from pysqlite2 import dbapi2 as sqlite3
import sys, os, re
import numpy
import itertools

from glue import git_version
from glue import iterutils
from glue.ligolw import lsctables
from glue.ligolw import dbtables

from pylal import InspiralUtils
from pylal import CoincInspiralUtils
from pylal import ligolw_sqlutils as sqlutils

__author__ = "Collin Capano <cdcapano@physics.syr.edu>"
__prog__ = "ligolw_cbc_plotcumhist"

description = \
"Given a ranking stat, generates cumulative histograms of coincident events."

# ============================================================================
#
#                               Set Options
#
# ============================================================================

def parse_command_line():
    """
    Build the option parser, parse the command line and check that the
    required options were given.

    Returns a tuple (options, argv): options is the optparse values object
    and argv is the raw argument list sys.argv[1:] (NOT the left-over
    positional arguments returned by optparse). The main script hands this
    list on to InspiralUtils.write_html_output.

    Raises ValueError if --input, --coinc-table, --ranking-stat or --rank-by
    is missing, or if --param-ranges is given without --param-name.
    """
    parser = OptionParser(
        version = git_version.verbose_msg,
        usage   = "%prog [options]",
        description = description
        )
    # following are related to file input and output naming
    parser.add_option( "-i", "--input", action = "store", type = "string", default = None,
        help =
            "Input database to read. Can only input one at a time."
        )
    parser.add_option( "-t", "--tmp-space", action = "store", type = "string", default = None,
        metavar = "PATH",
        help =
            "Location of local disk on which to do work. This is optional; " +
            "it is only used to enhance performance in a networked " +
            "environment. "
        )
    parser.add_option( "-P", "--output-path", action = "store", type = "string",
        default = os.getcwd(), metavar = "PATH",
        help =
            "Optional. Path where the figures should be stored. Default is current directory."
        )
    parser.add_option( "-O", "--enable-output", action = "store_true",
        default =  False, metavar = "OUTPUT",
        help =
            "enable the generation of html and cache documents"
        )
    parser.add_option( "", "--coinc-table", action = "store", type = "string",
        default = None,
        help =
            "Required. Can be any table with a coinc_event_id and a time column. Ex. coinc_inspiral."
        )
    parser.add_option( "-p", "--plot-playground-only", action = "store_true",
        default = False,
        help =
            "Plot only playground data. Use if plotting un-opened box."
        )
    parser.add_option( "-u", "--user-tag", action = "store", type = "string",
        default = None, metavar = "USERTAG",
        help =
            "Add a tag in the name of the figures"
        )
    parser.add_option( "-s", "--show-plot", action = "store_true", default = False,
        help =
            "display the plots on the terminal"
        )
    # following control how the stats are ranked and binned
    parser.add_option( "", "--num-bins", action = "store", type = "int", default = 20,
        help =
            "The number of bins to histogram the stats into. Default is 20."
        )
    parser.add_option( "", "--ranking-stat", action = "store", type = "string", default = None,
        help =
            "Required. The stat to use to rank triggers for plotting. " +
            "This can be any column in the coinc_inspiral  table."
        )
    parser.add_option( "", "--rank-by", action = "store", type = "string", default = None,
        metavar = "MAX or MIN",
        help =
            "Required. Options are MAX or MIN. " +
            "This specifies whether to rank triggers by ascending (MAX) or " +
            "descending (MIN) stat value."
        )
    parser.add_option( "", "--square-stats", action = "store_true", default = False,
        help =
            "Square the ranking-stat before plotting."
        )
    parser.add_option( "", "--param-name", action = "store", default = None,
        metavar = "PARAMETER",
        help =
            "Can be any parameter in the coinc_inspiral table. " +
            "Specifying this and param-ranges will generate cumulative histogram plots " +
            "broken-down by the param-bins (combined-plots are always generated)."
        )
    parser.add_option( "", "--param-ranges", action = "store", default = None,
        metavar = " [ LOW1, HIGH1 ); ( LOW2, HIGH2]; etc.",
        help =
            "Requires --param-name. Specify the parameter ranges " +
            "to bin the triggers in. A '(' or ')' implies an open " +
            "boundary, a '[' or ']' a closed boundary. To specify " +
            "multiple ranges, separate each range by a ';'."
        )
    parser.add_option( "-v", "--verbose", action = "store_true", default = False,
        help =
            "print information to stdout"
        )

    # the left-over positional arguments are deliberately ignored
    (options, _) = parser.parse_args()

    # check that every option documented as required was specified; without
    # this the script only fails later, with an unhelpful error on None
    required = (
        ("--input", options.input),
        ("--coinc-table", options.coinc_table),
        ("--ranking-stat", options.ranking_stat),
        ("--rank-by", options.rank_by),
        )
    for flag, value in required:
        if not value:
            raise ValueError("%s must be specified" % flag)

    # check for self-consistency: the per-range file tags are built from the
    # parser created for --param-name, so ranges alone cannot be used
    if options.param_ranges and not options.param_name:
        raise ValueError("--param-ranges requires --param-name")

    return options, sys.argv[1:]
            

# =============================================================================
#
#                       Function Definitions
#
# =============================================================================

def plotcumhist(figure_number, zero_lag_list, slide_dict, nbins, ranking_stat,
    rank_by, bkg_correction = 1., frg_marker = '^', frg_color = 'b', title_txt = ''):
    """
    Draw cumulative histograms of zero-lag and time-slide statistics on the
    matplotlib figure numbered figure_number. Returns nothing.

    zero_lag_list  -- list of zero-lag stat values, drawn as single markers
    slide_dict     -- dictionary whose values are lists of slide stat values;
                      each value is histogrammed separately, then the mean
                      over all of them is drawn as '+' markers with a shaded
                      band from (mean - std) to (mean + std)
    nbins          -- number of equal-width bins spanning the smallest to
                      the largest of all the stat values given
    ranking_stat   -- x-axis label (underscores are replaced by spaces)
    rank_by        -- 'MIN' accumulates the counts from the lowest bin
                      upward and reverses the x-axis; any other value
                      accumulates from the highest bin downward
    bkg_correction -- factor applied to the slide mean; the slide standard
                      deviation is scaled by its square root
    frg_marker     -- marker style of the zero-lag points
    frg_color      -- face color of the zero-lag points
    title_txt      -- figure title; nothing is drawn if it is empty

    If there are neither zero-lag nor slide values only a "no data" message
    is written on the figure.
    """
    # smallest value drawn; keeps empty bins representable on the log axis
    floor = 0.00001

    def cumulate(counts):
        # turn per-bin counts into cumulative counts in the rank_by direction
        if rank_by == 'MIN':
            return counts.cumsum()
        return counts[::-1].cumsum()[::-1]

    figure(figure_number)
    if title_txt != '':
        title(title_txt, size = 'x-large')

    have_zero_lag = zero_lag_list != []
    have_slides = any(slide_dict.values())

    # nothing to histogram: leave a message on the figure and stop here
    if not have_zero_lag and not have_slides:
        text( 0.5, 0.5, "No zero-lag or time-slide data to plot.",
            ha = 'center', va = 'center' )
        xlabel( r"No data", size='x-large' )
        ylabel( r"No data", size='x-large' )
        return

    # common binning over the zero-lag and slide values taken together
    all_stats = zero_lag_list + list(iterutils.flatten(slide_dict.values()))
    bins = numpy.linspace(min(all_stats), max(all_stats), nbins+1, endpoint = True)
    ds = (bins[1] - bins[0]) / 2
    # markers sit at the bin centres
    xvals = bins[:-1] + ds

    # lower y-limit reference used when there are no slides
    y_min = 0.8

    # zero-lag cumulative histogram
    if have_zero_lag:
        zero_counts = numpy.histogram(zero_lag_list, bins)[0]
        semilogy(xvals, cumulate(zero_counts)+floor,
            linestyle = 'None',
            marker = frg_marker,
            markersize=12,
            markerfacecolor = frg_color,
            markeredgecolor = 'k')

    # slide cumulative histograms: one row per entry of slide_dict
    if have_slides:
        per_slide = numpy.array([
            cumulate(numpy.histogram(slide_list, bins)[0])
            for slide_list in slide_dict.values() ])
        slide_mean = per_slide.mean(axis=0) * bkg_correction
        slide_std = per_slide.std(axis=0) * numpy.sqrt(bkg_correction)
        # lower edge of the band uses the un-floored mean, as does nothing else
        slide_min = [ max(mean - std, floor)
            for mean, std in zip(slide_mean, slide_std) ]
        slide_mean = numpy.maximum(slide_mean, floor)
        semilogy(xvals, slide_mean,
            linestyle = 'None',
            marker = '+',
            markersize = 12,
            markerfacecolor = 'k',
            markeredgecolor = 'k')
        # NOTE(review): makesteps comes from pylal.viz; presumably it returns
        # the vertices of the stepped band between the two curves - confirm
        tmpx, tmpy = makesteps( xvals, slide_min, slide_mean + slide_std )
        # shift by half a bin so the steps start at the bin edges
        fill( (tmpx-ds), tmpy,
            facecolor = 'y',
            alpha = 0.3 )
        y_min = min(slide_mean)

    # set limits, labels, and titles
    ylim( ymin = 0.5*y_min, ymax = max(gca().get_ylim()[1], 1.2) )
    # flip limits for rank-by MIN
    if rank_by == 'MIN':
        xlim( gca().get_xlim()[::-1] )
    xlabel(ranking_stat.replace('_', ' '), size = 'x-large')
    ylabel('Number of Events', size = 'x-large')
    grid(True)

            

# ============================================================================
#
#                                 Main
#
# ============================================================================

#
#   Generic Initialization
#

# parse command line
opts, args = parse_command_line()

# get input database filename
filename = opts.input
if not os.path.isfile( filename ):
    raise ValueError("The input file, %s, cannot be found." % filename)

# --coinc-table is documented as required; check it before the working
# database file is requested so that a missing option fails cleanly
if opts.coinc_table is None:
    raise ValueError("--coinc-table must be specified")

# Setup working databases and connections
if opts.verbose:
    sys.stdout.write("Creating a database connection...\n")
# NOTE(review): presumably this returns a scratch copy under tmp_path when
# one is given - confirm against glue.ligolw.dbtables
working_filename = dbtables.get_connection_filename(
    filename, tmp_path = opts.tmp_space, verbose = opts.verbose )
connection = sqlite3.connect( working_filename )
if opts.tmp_space:
    dbtables.set_temp_store_directory(connection, opts.tmp_space, verbose = opts.verbose)

# table name as it will be spliced into the SQL queries below
coinc_table = sqlutils.validate_option( opts.coinc_table )

#
#   Plotting Initialization
#

# Change to the non-interactive Agg back-end if show() will not be called,
# thus avoiding display problems. The back-end is selected before pyplot
# is imported below, so the order of these statements matters.
if not opts.show_plot:
  from matplotlib import use
  use('Agg')
# the star import supplies the plotting names (figure, semilogy, text, ...)
# that plotcumhist and the main loop call unqualified
from matplotlib.pyplot import *
from pylal.viz import makesteps
# render figure text with LaTeX; text passed to the plots must therefore
# be valid LaTeX (e.g. no bare underscores)
rc('text', usetex=True)

#
#   Program-specific Initialization
#

# get number of bins
nbins = opts.num_bins

# Get ranking stat and rank-by. Both are documented as required, so fail
# with a clear message instead of an AttributeError on None.
if opts.ranking_stat is None:
    raise ValueError("--ranking-stat must be specified")
if opts.rank_by is None:
    raise ValueError("--rank-by must be specified")

ranking_stat = opts.ranking_stat.strip()
if " " in ranking_stat:
    raise ValueError("ranking-stat must not contain spaces")
# qualify the stat with the table name unless the user already did
if coinc_table+'.' not in ranking_stat:
    ranking_stat = '.'.join([coinc_table, ranking_stat])
# set plot-stat: this is just for plot titles and file naming
plot_stat_name = ranking_stat.replace(coinc_table+'.', '')
if opts.square_stats:
    plot_stat_name += '$^{2}$'

rank_by = opts.rank_by.strip().upper()
if rank_by not in ('MAX', 'MIN'):
    raise ValueError("rank-by must be set to MAX or MIN")

# get param-ranges
if opts.param_name:
    trigsymbols = itertools.cycle(( 'v', 'o', 's' ))
    param_parser = sqlutils.parse_param_ranges( coinc_table, opts.param_name,
        opts.param_ranges, verbose = opts.verbose )
    # param_parser.param is of the form <table>.<column>; keep the column
    param_name = param_parser.param.split('.')[1]
    # make the range lookup callable from SQL; its result becomes the
    # param_group of every row (rows with a None group are skipped later)
    connection.create_function("group_by_param", 1, param_parser.group_by_param_range)
    param_grouping = ''.join([ 'group_by_param(', param_parser.param, ')' ])

    # comparison operator -> interval bracket used in the plot titles
    lower_brackets = { '>=': '[', '>': '(' }
    upper_brackets = { '<=': ']', '<': ')' }
    param_ranges = {}
    # the marker cycle is endless; zip stops with the last parameter range
    for (n, (low, high)), marker in zip(enumerate(param_parser.param_ranges), trigsymbols):
        # an unknown operator used to leave the bracket unbound (or silently
        # reuse the one from the previous range); report it instead
        try:
            lowbnd = lower_brackets[low[0]]
            highbnd = upper_brackets[high[0]]
        except KeyError:
            raise ValueError("unrecognized boundary operator in param-range %d: %s, %s"
                % (n, low[0], high[0]))
        param_ranges[(param_name, n)] = {}
        param_ranges[(param_name, n)]['range'] = \
            ''.join([lowbnd, str(low[1]), ',', str(high[1]), highbnd])
        param_ranges[(param_name, n)]['marker'] = marker

else:
    # no binning requested: a constant SQL expression puts every row in
    # the single param group 0
    param_grouping = '0'
    param_name = 'no_binning'
    param_ranges = { (param_name,0):{} }

# Get all the experiment_ids
# Query for the ranking stat of every non-simulation coincident event of one
# experiment. Table and column names have to be spliced into the SQL text,
# but the experiment_id is a value, so it is bound as a parameter ('?')
# rather than quoted by hand. The text is the same for every experiment and
# is therefore built once.
stats_query = ''.join(["""
    SELECT
        experiment_summary.experiment_summ_id,
        experiment_summary.duration,
        experiment_summary.datatype,
        """, coinc_table, """.ifos,
        """, param_grouping, """,
        """, ranking_stat, """
    FROM
        experiment_summary
    JOIN
        """, coinc_table, """, experiment_map ON (
            experiment_summary.experiment_summ_id == experiment_map.experiment_summ_id
            AND experiment_map.coinc_event_id == """, coinc_table, """.coinc_event_id )
    WHERE
        experiment_summary.datatype != 'simulation'
        AND experiment_summary.experiment_id == ?
    """])

# number of parameter groups; 1 when no binning was requested
n_param_groups = len(param_ranges)

sqlquery = """
    SELECT
        experiment.experiment_id,
        experiment.gps_start_time,
        experiment.gps_end_time,
        experiment.instruments
    FROM
        experiment
    """
for this_eid, gps_start_time, gps_end_time, on_instruments in connection.cursor().execute(sqlquery):
    on_instruments = lsctables.instrument_set_from_ifos(on_instruments)
    on_instr = r','.join(sorted(on_instruments)) + r' Time:'
    # figure out all possible ifo combinations and param_groupings for this experiment
    categories = [ (frozenset(ifos), param_group)
        for ifos in CoincInspiralUtils.get_ifo_combos(list(on_instruments))
        for param_group in range(n_param_groups) ]

    if opts.verbose:
        sys.stdout.write("Creating plots for %s\n" % on_instr)
        sys.stdout.write("\tgetting data...\n")

    #
    #   Gather data
    #
    # zero_lag_stats[datatype][(ifos, param_group)] -> list of stats
    # slide_stats[experiment_summ_id][(ifos, param_group)] -> list of stats
    # durations[datatype] -> duration of that (non-slide) datatype
    zero_lag_stats = {}
    slide_stats = {}
    durations = {}
    for esid, duration, datatype, ifos, param_group, stat in \
            connection.cursor().execute(stats_query, (this_eid,)):
        # skip if not in desired param ranges
        if param_group is None:
            continue
        ifos = frozenset(lsctables.instrument_set_from_ifos(ifos))
        if opts.square_stats:
            stat = stat**2.
        if datatype != 'slide':
            # the duration is recorded even for datatypes that are not
            # plotted: all_data is needed to scale the background
            durations[datatype] = float(duration)
            if opts.plot_playground_only and datatype in ('all_data', 'exclude_play'):
                continue
            if datatype not in zero_lag_stats:
                zero_lag_stats[datatype] = dict( [category, []] for category in categories )
            zero_lag_stats[datatype][(ifos, param_group)].append(stat)
        else:
            slide_stats.setdefault(esid, {}).setdefault((ifos, param_group), []).append(stat)
    if not zero_lag_stats:
        # create a closed box plot with an empty playground to show the bkg distribution
        zero_lag_stats[u'playground'] = dict( [category, []] for category in categories )

    #
    #   Plotting
    #

    # set InspiralUtils options for file and plot naming
    opts.gps_start_time = gps_start_time
    opts.gps_end_time = gps_end_time
    opts.ifo_times = ''.join(sorted(on_instruments))
    opts.ifo_tag = ''
    InspiralUtilsOpts = InspiralUtils.initialise( opts, __prog__, git_version.verbose_msg )

    # fnameList and tagList are always appended to together, so entry i of
    # one belongs to entry i of the other
    fnameList = []
    tagList = []
    figure_numbers = []
    fig_num = 0

    # cycle over datatypes, creating plots for each
    if opts.verbose:
        sys.stdout.write("\tgenerating plot figures for:\n")
    for datatype in zero_lag_stats:

        if opts.verbose:
            sys.stdout.write("\t\t%s...\n" % datatype)
        # set open box flag
        open_box = (datatype == 'all_data') or (datatype == 'exclude_play')

        # set bkg_correction: the slides are scaled by the fraction of the
        # all_data time that this datatype covers
        warn_msg = ''
        if 'all_data' not in durations or durations['all_data'] == 0:
            warn_msg = 'Warning: No all-data time found. Background may not be scaled correctly.'
            bkg_correction = 1.
        elif datatype not in durations:
            # happens for the empty playground entry created above, which has
            # no row (and hence no duration) in the database; underscores are
            # removed because the text is rendered by LaTeX
            warn_msg = 'Warning: No %s time found. Background may not be scaled correctly.' \
                % re.sub('_', '-', datatype)
            bkg_correction = 1.
        else:
            bkg_correction = durations[datatype] / durations['all_data']

        #
        # combined plot
        #
        fig_num += 1
        figure_numbers.append(fig_num)
        zero_lag_list = [ val for vallist in zero_lag_stats[datatype].values() for val in vallist ]
        slide_dict = {}
        for esid in slide_stats:
            slide_dict[esid] = [ val for vallist in slide_stats[esid].values() for val in vallist ]

        title_txt = ' '.join([ on_instr, re.sub('_', '-', datatype.upper()), "Cumulative Histogram of",
            re.sub('_', ' ', plot_stat_name) ])

        plotcumhist( fig_num, zero_lag_list, slide_dict, opts.num_bins, plot_stat_name, rank_by,
            bkg_correction = bkg_correction, frg_marker = '^', frg_color = 'b',
            title_txt = title_txt )

        if warn_msg != '':
            text( 0,0, warn_msg,
                va = 'bottom', ha = 'left',
                transform = figure(fig_num).gca().transAxes )

        if opts.enable_output:
            # NOTE(review): '*' becomes '_x_' here but 'x' in the per-category
            # file names below; left as is so existing file names do not change
            name = InspiralUtils.set_figure_tag( '_'.join([ "cumhist_combined",
                plot_stat_name.replace('*', '_x_').replace('/', '_div_').replace('-', '_sub_').replace('$^{2}$', '_sq') ]),
                datatype_plotted = datatype.upper(), open_box = open_box)
            fname = InspiralUtils.set_figure_name(InspiralUtilsOpts, name)
            InspiralUtils.savefig_pylal( filename=fname )
            fnameList.append(fname)
            tagList.append(name)

        #
        # plots by coincidence and param-bin (only do if there is more then one category)
        #
        if len(categories) > 1:
            for (ifos, param_group), zero_lag_list in zero_lag_stats[datatype].items():
                fig_num += 1
                figure_numbers.append(fig_num)
                # construct the right slide-dict
                slide_dict = dict([ [esid, slide_stats[esid][(ifos, param_group)]] for esid in slide_stats
                    if (ifos, param_group) in slide_stats[esid] ])
                # set marker, title
                if param_name == 'no_binning':
                    title_lbl = r','.join(sorted(ifos))
                    frg_marker = '^'
                else:
                    title_lbl = ' '.join([ r','.join(sorted(ifos)), param_name,
                        param_ranges[(param_name, param_group)]['range'] ])
                    frg_marker = param_ranges[(param_name, param_group)]['marker']
                title_txt = ' '.join([ on_instr, datatype.upper(), title_lbl ])
                # remove underscores to prevent errors
                title_txt = re.sub('_', '-', title_txt)

                plotcumhist( fig_num, zero_lag_list, slide_dict, opts.num_bins, plot_stat_name, rank_by,
                    bkg_correction = bkg_correction, frg_marker = frg_marker,
                    frg_color = InspiralUtils.get_coinc_ifo_colors( ifos ),
                    title_txt = title_txt )

                if warn_msg != '':
                    text( 0,0, warn_msg,
                        va = 'bottom', ha = 'left',
                        transform = figure(fig_num).gca().transAxes )

                if opts.enable_output:
                    param_tag = param_name
                    if opts.param_ranges:
                        param_tag = '_'.join([ param_tag,
                            str(param_parser.param_ranges[param_group][0][1]),
                            str(param_parser.param_ranges[param_group][1][1]) ])
                    plot_description = '_'.join([ param_tag, ''.join(sorted(ifos)), "cumhist",
                        plot_stat_name.replace('*', 'x').replace('/', '_div_').replace('-', '_sub_').replace('$^{2}$', '_sq') ])
                    name = InspiralUtils.set_figure_tag( plot_description,
                        datatype_plotted = datatype.upper(), open_box = open_box )
                    fname = InspiralUtils.set_figure_name(InspiralUtilsOpts, name)
                    InspiralUtils.savefig_pylal( filename=fname )
                    fnameList.append(fname)
                    tagList.append(name)


    #
    #   Create the html page for this experiment id
    #

    if opts.enable_output:
        if opts.verbose:
            sys.stdout.write("\twriting html file and cache.\n")

        # create html of closed box plots; the tags are filtered together
        # with the file names so that the two lists stay aligned
        # NOTE(review): assumes write_html_output pairs files and tags by
        # position - confirm against pylal.InspiralUtils
        closed_pairs = [ (fname, tag) for fname, tag in zip(fnameList, tagList)
            if 'OPEN_BOX' not in fname ]
        closed_list = [ fname for fname, _ in closed_pairs ]
        closed_tags = [ tag for _, tag in closed_pairs ]
        closed_html = InspiralUtils.write_html_output( InspiralUtilsOpts, args, closed_list,
            closed_tags, add_box_flag = True )
        InspiralUtils.write_cache_output( InspiralUtilsOpts, closed_html, closed_list )

        # create html with open and closed box plots in them
        if not opts.plot_playground_only:
            open_html = InspiralUtils.write_html_output( InspiralUtilsOpts, args, fnameList,
                tagList, add_box_flag = True )
            InspiralUtils.write_cache_output( InspiralUtilsOpts, open_html, fnameList )

    if opts.show_plot:
        show()

    #
    # Close the figures and clear memory for the next instrument time
    #

    for number in figure_numbers:
        close(number)

    # clean memory
    del zero_lag_stats
    del slide_stats
    del durations

#
#   Finished cycling over experiments; exit
#
# close the connection first; the working file is handed back only after
# nothing is using it any more
connection.close()
# counterpart of the dbtables.get_connection_filename() call made during
# initialization
# NOTE(review): presumably this removes the scratch copy made under
# --tmp-space - confirm against glue.ligolw.dbtables
dbtables.discard_connection_filename( filename, working_filename, verbose = opts.verbose)

if opts.verbose:
    print >> sys.stdout, "Finished!"
# explicit success exit status
sys.exit(0)
