#!/usr/bin/env python
#
# Copyright (C) 2013  Branson Stephens
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#

import pygtk
pygtk.require("2.0")
import pygst
pygst.require("0.10")
import gobject
import gst

from gstlal import simplehandler
from gstlal import pipeparts
from gstlal import datasource
from optparse import OptionParser, Option
import math
import glob
import gzip
import os
from collections import deque
from glue import segments
import numpy as np

#
# =============================================================================
#
#                                   Utilities
#
# =============================================================================
#

# Older versions of numpy.load choke on .gz files.
# Not sure what version fixes this, but the numpy package 
# for debian squeeze (1.4.1) doesn't. See
# https://github.com/numpy/numpy/issues/1593

def wrapNpLoad(filePath):
    """Load and return a numpy array from filePath, transparently
    decompressing gzipped files.

    Older versions of numpy.load choke on .gz files (e.g. numpy 1.4.1
    shipped with debian squeeze); see
    https://github.com/numpy/numpy/issues/1593
    """
    if filePath.endswith('gz'):
        fileObj = gzip.open(filePath)
        try:
            return np.load(fileObj)
        finally:
            # Close the handle explicitly; the original implementation
            # leaked the open gzip file object.
            fileObj.close()
    else:
        return np.load(filePath)

# A class for handling the veto timeseries files.

class vetoSource:
    """Feed veto timeseries data from numpy files on disk into a
    gstreamer appsrc element as a stream of 64-bit float buffers.

    Input file names are assumed to have the form
        <input_prefix><gpsstart>-<duration><input_ext>
    where gpsstart and duration are integers in seconds.  Files found
    at construction time are queued; each "need-data" callback checks
    the disk for newly arrived files, pops the next file off the queue
    and pushes its contents downstream (preceded by a gap buffer if
    the file starts later than the current output timestamp).
    """
    def __init__(self, inputPath, inputPre, inputExt):
        # get the file list
        pathPrefix = os.path.join(inputPath,inputPre)
        self.pathPrefix = pathPrefix
        self.inputPre = inputPre
        self.inputExt = inputExt
        filePathList = self.check_for_new_files()
        self.fileQueue = deque(filePathList)
        # This is the offset of a given buffer with respect to the global stream in 
        # units of *samples*.
        self.current_offset = 0
    
        # Determine the rate by looking at the first file in the list.
        # XXX If the files with a given prefix have different rates, this will break.
        # Assume file names are of the form:
        # <input_prefix><gpsstart>-<duration><input_ext>
        # NOTE(review): filePathList[0] raises IndexError if no matching
        # files exist on disk at startup -- confirm callers guarantee at
        # least one file is present before constructing this object.
        filePath = filePathList[0]
        firstVals = wrapNpLoad(filePath)
        gps_start, rest = filePath[len(pathPrefix):].split('-')
        self.duration = int(rest[:rest.find(inputExt)])
        # Output timestamps are tracked in gstreamer time units
        # (nanoseconds, since gst.SECOND == 1e9).
        self.next_output_timestamp = int(gps_start) * gst.SECOND

        # XXX The rate determination assumes the file contains N values with
        # t_i = gpsstart + (i-1)*Delta_t ,
        # such that the last data point is at time t = (gpsstart+duration)-Delta_t.
        # NOTE gstreamer expects an integer rate in Hz.
        # XXX This cast makes me uncomfortable.  That's why I added 0.1.
        self.rate = int(float(firstVals.size)/float(self.duration) + 0.1)
        # Now that we have the rate, we can set the caps for the appsrc
        self.caps = "audio/x-raw-float,width=64,depth=64,channels=1,rate=%d" % self.rate

    def check_for_new_files(self, timestamp=0):
        """Return a sorted list of input file paths whose GPS start time
        (parsed from the filename) is >= timestamp.

        NOTE(review): need_data() passes self.next_output_timestamp here,
        which is in gstreamer nanoseconds, while the filename GPS start is
        in seconds -- the comparison mixes units, so once playback begins
        essentially no file will qualify and newly arrived files are never
        picked up.  A naive unit fix would, however, re-queue files that
        are already in fileQueue (there is no dedup) -- TODO confirm the
        intended behavior before changing this.
        """
        # get the file list
        pattern = self.pathPrefix + '*' + self.inputExt
        def is_current_file(path):
            # Parse the GPS start out of <prefix><gpsstart>-<duration><ext>.
            filePath = os.path.basename(path)
            rest = filePath[len(self.inputPre):]
            if len(rest)>0:
                return int(rest.split('-')[0])>=timestamp
            else:
                # No usable remainder after the prefix; treat as falsy.
                return None
        # Python 2 filter() returns a list here, so in-place sort works.
        filePathList = filter(is_current_file, glob.glob(pattern))
        filePathList.sort()
        return filePathList
        
    def need_data(self, src, need_bytes=None):
        """Handler for the appsrc "need-data" signal.

        Pops the next file off the queue and pushes its contents into the
        stream via the appsrc's "push-buffer" signal, first pushing a gap
        buffer if the file's GPS start is ahead of the running output
        timestamp.  When the queue is empty, emits "end-of-stream".
        """
        # Check if new data has arrived on disk.
        self.fileQueue.extend(self.check_for_new_files(self.next_output_timestamp))
        try:
            # Get the gpsstart time from the filename.
            # popleft() raises IndexError when the queue is empty; the
            # except clause below converts that into end-of-stream.
            filePath = self.fileQueue.popleft()
            rest = filePath[len(self.pathPrefix):]
            gpsstart = int(rest.split('-')[0])
            # Let's re-derive the duration.  maybe it changed?
            rest = rest.split('-')[1]
            duration = int(rest[:rest.find(self.inputExt)])

            # Push gap if the current block is ahead of our timestamps.
            # FIXME: Give user some wiggle room to wait for the next file?
            if gpsstart * gst.SECOND > self.next_output_timestamp:
                # Build the buffer.
                # NOTE(review): no caps are set on this gap buffer, unlike
                # the data buffer below -- confirm downstream elements
                # tolerate that.
                buf = gst.buffer_new_and_alloc(0)
                buf.flag_set(gst.BUFFER_FLAG_GAP)
#                buf.timestamp = gpsstart * gst.SECOND
#                gap_duration = self.next_output_timestamp - gpsstart * gst.SECOND
#                buf.duration = gap_duration
                buf.timestamp = self.next_output_timestamp
                gap_duration = gpsstart * gst.SECOND - buf.timestamp
                gap_samples = int ((gap_duration / gst.SECOND) * self.rate)
                buf.duration = gap_duration
                buf.offset = self.current_offset
                # FIXME: What if the gap_duration != normal duration?
                # A: We'll need a next_output_offset as well
                buf.offset_end = self.current_offset + gap_samples
                src.emit("push-buffer", buf)

                self.next_output_timestamp += buf.duration 
                self.current_offset = buf.offset_end
                # FIXME I don't think we should actually be returning here.
                # return
           
            # Load the numpy array.
            veto_vals = wrapNpLoad(filePath)
            veto_vals = veto_vals.astype(np.float64)

            # Build the buffer.
            buffer_len = veto_vals.nbytes
            buf = gst.buffer_new_and_alloc(buffer_len)
            # NOTE(review): the slice end (buffer_len-1) looks off by one
            # for copying buffer_len bytes; pygst buffer slice-assignment
            # semantics may absorb this -- TODO confirm.
            buf[:buffer_len-1] = np.getbuffer(veto_vals)
            buf.timestamp = gpsstart * gst.SECOND
            # gst buffers require:
            # buffer_duration * rate / gst.SECOND = (offset_end - offset)
            # The offset is zero since our data begin at the beginning 
            # of the buffer.
            buf.duration = duration * gst.SECOND
            buf.offset = self.current_offset
            buf.offset_end = self.current_offset + duration * self.rate
            buf.caps = self.caps

            self.next_output_timestamp += buf.duration 

            # Push the buffer into the stream (a side effect of 
            # emitting this signal).
            src.emit("push-buffer", buf)

            self.current_offset = buf.offset_end
        except IndexError:
            # Queue exhausted and no new files found: signal EOS.
            src.emit("end-of-stream")

#
# =============================================================================
#
#                                   Options
#
# =============================================================================
#

# Build the command-line parser; datasource contributes its own
# standard options (channel selection, etc.) on top of ours.
parser = OptionParser(description = __doc__)
datasource.append_options(parser)

# Output options: frame layout and destination (files or shared memory).
parser.add_option("--frame-type", metavar = "name", help = "Specify the non-instrumental part of the frame type. The full frame type will be constructed by prepending the instrument.")
parser.add_option("--frame-duration", metavar = "s", default = 64, type = "int", help = "Set the duration of the output frames")
parser.add_option("--frames-per-file", metavar = "s", default = 1, type = "int", help = "Output frames per file")
parser.add_option("--output-channel-name", metavar = "name", help = "If an additional frame cache is requested, indicate the channel name to extract.")
parser.add_option("--output-path", metavar = "name", help = "Path to output frame files.")
parser.add_option("--output-type", metavar = "name", help = "Method of output. Valid choices are files 'files' (default) and shared memory 'shm'", default = "files")
parser.add_option("--shm-partition", metavar = "name", help = "Shared memory partition to write frames to. Required in case of output-type = shm, ignored otherwise.")
# Input options: where the numpy veto timeseries files live and how
# their names are constructed (<prefix><gpsstart>-<duration><ext>).
parser.add_option("--input-path", metavar = "name", help = "Path to input numpy files.")
parser.add_option("--input-prefix", metavar = "name", help = "Prefix for numpy files.")
parser.add_option("--input-ext", metavar = "name", help = "Extension for numpy files.", default = ".npy.gz")
parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose (optional).")

options, filenames = parser.parse_args()

#
# =============================================================================
#
#                                    Main
#
# =============================================================================
#

#
# Set the pipeline up
#

# Build the pipeline, the GLib main loop that drives it, and the
# message handler that stops the loop on EOS/error.
pipeline = gst.Pipeline("example")
mainloop = gobject.MainLoop()
handler = simplehandler.Handler(mainloop,pipeline)

#
# Setup.
#

channel_list = [options.output_channel_name,]
channel_dict = datasource.channel_dict_from_channel_list(channel_list)
# Assume instrument is the first (only) key of the channel dict.
# (list(...) keeps this working on Python 3 dict views as well.)
instrument = list(channel_dict.keys())[0]
channel_name = channel_dict[instrument]

# Check the instrument from the input_prefix: the observatory encoded in
# the input file prefix must match the requested output channel.
obsStr = options.input_prefix.split('-')[0]
if not instrument.startswith(obsStr):
    raise ValueError("Output channel instrument clashes with input prefix.")


# Setup the source class that feeds veto timeseries files to the appsrc.
vetoSrc = vetoSource(options.input_path, options.input_prefix,
            options.input_ext)

# Create the appsrc with accoutrements: caps derived from the first input
# file, and the need-data callback that pushes buffers on demand.
appsrc = pipeparts.mkgeneric(pipeline, None, "appsrc", caps=gst.Caps(vetoSrc.caps), 
    format="time")
appsrc.connect('need-data', vetoSrc.need_data)
src = pipeparts.mktaginject(pipeline, appsrc, 
    "instrument=%s,channel-name=%s,units=cats" % (instrument, channel_name))

# Create the output directory if one was requested.  Tolerate a
# pre-existing directory, but re-raise any other failure (permissions,
# bad path, ...) instead of silently swallowing it as before.
if options.output_path:
    try:
        os.makedirs(options.output_path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

# Define the muxer.
mux = pipeparts.mkframecppchannelmux(pipeline, None, 
    frames_per_file = options.frames_per_file, 
    frame_duration = options.frame_duration)

# Link the source to the muxer. 
src.get_pad("src").link(mux.get_pad(instrument + ':' + channel_name))

# Stick progress report in pipeline.
progRep = pipeparts.mkprogressreport(pipeline, mux, "progress_sink")

# Final destination.
# XXX multicast?
if options.output_type == "files":
    path = options.output_path
    if path:
        fs = pipeparts.mkframecppfilesink(pipeline, progRep, 
            frame_type = options.frame_type, path = options.output_path)
    else:
        fs = pipeparts.mkframecppfilesink(pipeline, progRep, 
            frame_type = options.frame_type)
elif options.output_type == "shm":
    lvshmsink = pipeparts.mkgeneric(pipeline, progRep, "gds_lvshmsink")
    lvshmsink.set_property("shm-name", options.shm_partition)
    lvshmsink.set_property("num-buffers", 10)
    # Let's guess the blocksize, then multiply by a fudge factor.
    # Note: This assumes the stream consists of 64-bit samples.
    # XXX This could be done better by attaching a pad probe to the lvhsmsink
    # sink pad and setting blocksize according to the size of the incoming buffers
    blocksize = vetoSrc.rate * 8 * options.frame_duration * options.frames_per_file
    blocksize = blocksize * 2
    lvshmsink.set_property("blocksize", blocksize)
    # FIXME: I think this means it needs to be read at least once
    lvshmsink.set_property("buffer-mode", 1)
else:
    raise ValueError("Invalid output type.")

#
# Start the thing going.
#

pipeline.set_state(gst.STATE_PLAYING)
mainloop.run()


#
# done
#
