#
# Mock-up NDS 1 server
#
# Based on:
#   * https://dcc.ligo.org/T980024
#   * https://dcc.ligo.org/T0900636
#   * and reverse-engineering packet formats with tcpdump
#
# Copyright (C) 2014  Leo Singer <leo.singer@ligo.org>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

from __future__ import print_function
import binascii
import collections
import os
import re
import shutil
import socket
import stat
import subprocess
import struct
import sys
import tempfile
import threading

# Metadata for one mock channel. Every field except `bytepattern` is reported
# by the 'status channels 2'/'status channels 3' replies; `bytepattern` is the
# raw bytes of one sample, used to synthesize data for 'start net-writer'.
Channel = collections.namedtuple('Channel', [
    'name', 'rate', 'type', 'tp_num', 'group',
    'unit', 'gain', 'slope', 'offset', 'bytepattern',
])


def _mock_channel(name, type_code, hex_pattern):
    """Return a Channel carrying the metadata shared by every mock channel.

    Only the name, the type code and the sample byte pattern (given as hex
    digits) vary between mock channels; rate is 16, tp_num and group are 0,
    the unit is 'Undef', gain and slope are 1 and offset is 0.
    """
    return Channel(name, 16, type_code, 0, 0, 'Undef', 1, 1, 0,
                   binascii.unhexlify(hex_pattern))


# NOTE(review): the type codes 1-6 are presumably NDS1 data-type codes; the
# pattern lengths (2, 4, 8, 4, 8, 8 bytes) look like the matching sample
# sizes -- TODO confirm against the protocol documents.
channels = [
    _mock_channel('X1:A', 1, 'B000'),
    _mock_channel('X1:B', 2, 'DEADBEEF'),
    _mock_channel('X1:C', 3, '00000000FEEDFACE'),
    _mock_channel('X1:D', 4, '600DF00D'),
    _mock_channel('X1:E', 5, 'BAADF00D33333333'),
    _mock_channel('X1:F', 6, '600DDEEDBAADFACE'),
]
# Lookup table used by channels_for_cmd() to resolve requested names.
channels_dict = dict((chan.name, chan) for chan in channels)


class UnknownChannelError(ValueError):
    """Raised when a command requests a channel name the mock server lacks."""


def channels_for_cmd(cmd):
    """Return the Channel records requested by *cmd*.

    If *cmd* contains no '{', every mock channel is returned (the module-level
    ``channels`` list itself). Otherwise the text from the first '{' onward,
    minus the surrounding braces, is split on whitespace. Double quotes are
    stripped from each token, and each name is looked up in ``channels_dict``,
    preserving request order.

    Raises:
        UnknownChannelError: if any requested name is not a mock channel.
    """
    brace_at = cmd.find('{')
    if brace_at < 0:
        return channels

    body = cmd[brace_at:].lstrip('{').rstrip('}')
    selected = []
    for token in body.split():
        wanted = token.strip('"')
        try:
            selected.append(channels_dict[wanted])
        except KeyError:
            raise UnknownChannelError
    return selected


def hex_float(f):
    """Return *f* packed as a big-endian 4-byte float, as lowercase hex digits."""
    packed = struct.pack('>f', f)
    return binascii.hexlify(packed)


def _write_status_channels_2(f, cmd):
    """Write the reply to a 'status channels 2' command to *f*.

    The reply is '0000', then the channel count as 8 uppercase hex digits,
    then one fixed-width text record per requested channel.

    Raises UnknownChannelError, before anything is written, if *cmd* names a
    channel that is not a mock channel.
    """
    # FIXME: HACK for broken servers that return errant tabs at the end of
    # channel names when individual channels are requested.
    hack = ''
    if os.getenv('NDS1_MOCKUP_SERVER_BROKEN', False) and '{' in cmd:
        hack = '\t'
    # FIXME: END HACK
    chans = channels_for_cmd(cmd)
    f.write('0000{0:08X}'.format(len(chans)))
    record_format = (
        '{name:60s}{hack:s}{rate:08X}{tp_num:08X}{group:04X}{type:04X}'
        '{hex_gain:s}{hex_slope:s}{hex_offset:s}{unit:40s}')
    for channel in chans:
        f.write(record_format.format(
            name=channel.name,
            rate=channel.rate,
            tp_num=channel.tp_num,
            group=channel.group,
            type=channel.type,
            unit=channel.unit,
            hex_gain=hex_float(channel.gain),
            hex_slope=hex_float(channel.slope),
            hex_offset=hex_float(channel.offset),
            hack=hack,
        ))


def _write_status_channels_3(f, cmd):
    """Write the reply to a 'status channels 3' command to *f*.

    The reply is '0000' followed by the decimal channel count and a newline.
    Then, for each requested channel, every Channel field except the trailing
    ``bytepattern`` follows as str(), one per line.

    Raises UnknownChannelError, before anything is written, if *cmd* names a
    channel that is not a mock channel.
    """
    chans = channels_for_cmd(cmd)
    f.write('0000{0:d}\n'.format(len(chans)))
    for channel in chans:
        # Skip last field, which is hex data pattern
        for field in channel[:-1]:
            f.write(str(field) + '\n')


def _stream_net_writer(f, cmd):
    """Write the reply and data stream for a 'start net-writer' command.

    With 'start net-writer <gps> <duration> ...' the stream is finite: one
    data block per step, starting at timestamp <gps>. Without those two
    numbers it starts at 1000000000 and never ends on its own; it runs until
    writing fails, e.g. because the client disconnects.

    Raises UnknownChannelError, before anything is written, if *cmd* names a
    channel that is not a mock channel.
    """
    chans = channels_for_cmd(cmd)
    match = re.match(r'start net-writer (\d+) (\d+)', cmd)
    if match:
        offline = True
        timestamp = int(match.group(1))
        duration = int(match.group(2))
    else:
        timestamp = 1000000000
        duration = float('inf')
        offline = False
    # NOTE(review): 12 ASCII '0' characters, presumably the '0000' status
    # plus further reply fields -- TODO confirm against T0900636.
    f.write('0' * 12)
    f.write(struct.pack('>I', offline))
    # The length field counts the four 4-byte ints that follow it (16 bytes)
    # plus one 12-byte record per channel.
    f.write(struct.pack('>Iiiii',
                        len(chans) * 12 + 16,
                        -1, -1, 3, -1))
    for channel in chans:
        # WTF! randomly in little-endian order?
        f.write(struct.pack('<ffI', channel.offset, channel.slope, 0))
    bytes_per_buffer = sum(len(channel.bytepattern) * channel.rate
                           for channel in chans)
    seq_num = 0
    while duration > 0:
        # Block header: length (the 16 bytes of header ints after it plus the
        # sample data), 1, timestamp, 0, sequence number.
        f.write(struct.pack('>IIIII',
                            16 + bytes_per_buffer,
                            1, timestamp, 0, seq_num))
        for channel in chans:
            # `rate` samples per block; each sample is the byte pattern reversed.
            f.write(channel.bytepattern[::-1] * channel.rate)
        seq_num += 1
        timestamp += 1
        duration -= 1
    # A header-only block (no sample data) marks the end of the stream.
    f.write(struct.pack('>IIIII',
                        16, 0, 0, 0, 0))


def _close_connection(f, conn):
    """Best-effort teardown of a client connection; every failure is ignored.

    Each step is attempted independently, so a failure in one (for example a
    flush error in f.close()) does not prevent the socket from being closed.
    """
    try:
        f.close()
    except Exception:
        pass
    try:
        conn.shutdown(socket.SHUT_RDWR)
    except Exception:
        pass
    try:
        conn.close()
    except Exception:
        pass


def serve(conn):
    """Serve mock NDS1 replies on the connected socket *conn*.

    Commands are read one line at a time. Surrounding whitespace and trailing
    ';' are stripped, and each command is echoed to stderr. The session ends
    when the client sends 'quit' or closes the connection, or when an I/O
    error occurs; the socket is always closed.

    A command that requests an unknown channel gets the reply '0004'. Any
    other unrecognized command raises ValueError out of this function, which
    drops the connection the way real NDS1 servers do.
    """
    f = conn.makefile('rwb')
    try:
        for cmd in f:
            cmd = cmd.strip().rstrip(';')
            print("MOCK SERVE CMD: '%s'" % cmd, file=sys.stderr)
            try:
                if cmd == '':
                    continue
                elif cmd == 'version':
                    f.write('0000000c')
                elif cmd == 'revision':
                    f.write('00000000')
                elif cmd[:17] == 'status channels 2':
                    _write_status_channels_2(f, cmd)
                elif cmd[:17] == 'status channels 3':
                    _write_status_channels_3(f, cmd)
                elif cmd.startswith('start net-writer'):
                    _stream_net_writer(f, cmd)
                elif cmd == 'quit':
                    break
                else:
                    # NDS1 servers drop the connection upon receiving
                    # an unknown command
                    raise ValueError("unknown command: %s" % cmd)
            except UnknownChannelError:
                f.write('0004')
            f.flush()
    except IOError:
        pass  # e.g. the client went away; just close the socket below
    finally:
        _close_connection(f, conn)


# Create socket bound to a random ephemeral port on the loopback interface
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
sock.bind(('127.0.0.1', 0))
_, port = sock.getsockname()
# Start listening now, on the main thread, before `port` is handed to the test
# subprocess. If listen() ran only inside the accept thread, the subprocess
# could connect before the socket was listening and get "connection refused".
sock.listen(2)


def listen():
    """Accept connections on `sock` forever, serving each on a daemon thread."""
    while True:
        conn, _ = sock.accept()
        serve_thread = threading.Thread(target=serve, args=(conn,))
        serve_thread.daemon = True
        serve_thread.start()


# Daemon thread, so it does not keep the process alive after sys.exit().
listen_thread = threading.Thread(target=listen)
listen_thread.daemon = True
listen_thread.start()

def _run_wrapped_command(banner):
    """Print *banner* to stderr, run the command in sys.argv[1:], return its status."""
    print(banner, file=sys.stderr)
    return subprocess.call(sys.argv[1:])


# Publish the mock server's port, then run the wrapped command twice against a
# scratch channel-database directory: first writable, then with its entries
# made read-only. The second run happens only if the first one succeeded.
os.environ['NDS_TEST_PORT'] = str(port)
db_dir = tempfile.mkdtemp()
try:
    os.environ['NDS1_MOCKUP_SERVER_BROKEN'] = 't'
    os.environ['NDS2_CHANNEL_DB_DIR'] = db_dir
    status = _run_wrapped_command('Testing with read-writable database...')
    if status == 0:
        # Only the top-level entries become owner-read-only; db_dir itself
        # stays writable, so shutil.rmtree() below can still remove them.
        for entry in os.listdir(db_dir):
            os.chmod(os.path.join(db_dir, entry), stat.S_IREAD)
        status = _run_wrapped_command('Testing with read-only database...')
finally:
    shutil.rmtree(db_dir)
sys.exit(status)
