# Copyright (c) 2005-2007 Forest Bond.
# This file is part of the sclapp software package.
# 
# sclapp is free software; you can redistribute it and/or modify it under the
# terms of the GNU General Public License version 2 as published by the Free
# Software Foundation.
# 
# A copy of the license has been included in the COPYING file.

import sys, os, re, unittest, doctest

import sclapp
from sclapp import debug_logging

def execHead(num_lines = 5):
    """Fork a child process that replaces itself with ``head -n<num_lines>``.

    The child inherits the caller's standard streams, so ``head`` reads from
    the caller's stdin and writes to the caller's stdout.

    num_lines -- number of lines ``head`` passes through before exiting.

    Returns the child's pid (in the parent process only).
    """
    pid = os.fork()
    if not pid:
        # Child process.  os.execvp only returns by raising; whatever happens,
        # the child must never fall back into the caller's code, otherwise two
        # processes would continue running the same test.
        try:
            try:
                os.execvp('head', [ 'head', '-n%u' % num_lines ])
            except Exception:
                sys.stderr.write(
                  'execHead: unable to exec head: %s\n' % (sys.exc_info()[1],)
                )
        finally:
            # os._exit skips atexit handlers and buffered-stream flushing that
            # belong to the parent.  127 is the conventional "command could
            # not be executed" status.
            os._exit(127)
    return pid

# Pattern for the log lines written by logSignals().  Group 1 is the pid of the
# process that wrote the line; group 2 is the (possibly empty) comma-separated
# list of signal numbers.
CAUGHT_SIGNALS_REGEX = r'pid ([0-9]+) caught signals: ([0-9, ]*)'
# Compiled once at import time; getLoggedSignals() applies it to every line of
# the log file.
caught_signals_regex_compiled = re.compile(CAUGHT_SIGNALS_REGEX)

def dumpLogFile():
    """Write the debug log file's contents to stdout between two 80-column
    rules.

    Everything goes through sys.stdout.write rather than the print statement,
    so this module also parses on interpreters where print is a function.
    """
    rule = 80 * '-'
    write = sys.stdout.write
    write('Logfile contents:\n')
    write(rule + '\n')
    # The log contents are written verbatim; no newline is appended.
    write(debug_logging.readLogFile())
    write(rule + '\n')

def getLoggedSignals(pid):
    """Return the signal numbers that process ``pid`` recorded in the log file.

    The log is scanned line by line for entries matching CAUGHT_SIGNALS_REGEX.
    The first entry whose pid equals ``pid`` wins, and its signal list is
    returned as a list of ints (empty if the entry lists no signals).

    Returns None when the log holds no entry for ``pid``.
    """
    for line in debug_logging.readLogFile().split('\n'):
        found = caught_signals_regex_compiled.match(line)
        if found is None:
            continue

        groups = found.groups()
        assert(len(groups) == 2)
        pid_text, signals_text = groups
        if int(pid_text) != pid:
            # Entry written by some other process.
            continue

        # Blank fields (e.g. from an empty signal list) are dropped before
        # the int conversion.
        fields = [ field.strip() for field in signals_text.split(',') ]
        return [ int(field) for field in fields if field ]
    return None

def verifySignalCaught(signum, pid):
    """Return True if the log file records that ``pid`` caught ``signum``.

    Returns False when the signal is not listed, and also when the log has no
    entry for ``pid`` at all.
    """
    logged = getLoggedSignals(pid)
    if logged is None:
        return False
    return signum in logged

def assertSignalCaught(signum, pid):
    """Assert that the log file records ``pid`` as having caught ``signum``.

    Raises AssertionError, with the log contents attached, otherwise.
    """
    # The message is only built when the assertion fails; the log is re-read
    # at that point so the failure report shows what was actually logged.
    assert verifySignalCaught(signum, pid), (
      'Signal %s was not logged as caught by pid %s.  Logfile:\n%s' % (
        signum, pid, debug_logging.readLogFile()
    ))

def logSignals():
    """Record in the debug log which signals the current process has caught.

    The line written has the form matched by CAUGHT_SIGNALS_REGEX: this
    process's pid followed by the comma-separated values reported by
    sclapp.getCaughtSignals().
    """
    pid = os.getpid()
    caught = ', '.join(map(str, sclapp.getCaughtSignals()))
    debug_logging.logMessage('pid %u caught signals: %s' % (pid, caught))

def waitForPid(pid):
    """Block until child process ``pid`` terminates.

    Returns the ``(pid, status)`` pair produced by os.waitpid.
    """
    # options = 0: plain blocking wait, no WNOHANG / WUNTRACED.
    reaped_pid, exit_status = os.waitpid(pid, 0)
    return reaped_pid, exit_status

def removeLogFile():
    """Delete the debug log file, ignoring failure to do so.

    Returns whatever debug_logging.removeLogFile() returns, or None when it
    raised OSError or IOError.
    """
    try:
        outcome = debug_logging.removeLogFile()
    except (IOError, OSError):
        # Best effort only: not being able to remove the log file is not an
        # error for the callers of this helper.
        return None
    return outcome

def redirectToLogFile():
    """Redirect both stdout and stderr to the debug log file.

    Returns the result of sclapp.processes.redirectFds().
    """
    # Imported here rather than at module level, as in the rest of this file.
    from sclapp import processes as s_processes

    logfile = debug_logging.DEBUG_LOGFILE
    return s_processes.redirectFds(stdout = logfile, stderr = logfile)

def assertLogFileContains(needle):
    """Assert that ``needle`` occurs somewhere in the debug log file.

    Raises AssertionError, with the log contents attached, otherwise.
    """
    haystack = debug_logging.readLogFile()
    # The message reuses the contents that were searched, so the report always
    # matches what the assertion actually saw.
    assert (needle in haystack), (
      'Expected to find %r.  Logfile:\n%s' % (needle, haystack)
    )

def assertLogFileDoesNotContain(needle):
    """Assert that ``needle`` occurs nowhere in the debug log file.

    Raises AssertionError, with the log contents attached, otherwise.
    """
    haystack = debug_logging.readLogFile()
    # The message reuses the contents that were searched, so the report always
    # matches what the assertion actually saw.
    assert (needle not in haystack), (
      'Expected not to find %r.  Logfile:\n%s' % (needle, haystack)
    )

def grepCount(haystack, needle):
    """Count the occurrences of ``needle`` in ``haystack``.

    Overlapping occurrences are counted separately (unlike str.count): after
    each hit the search resumes one character past the start of that hit.
    """
    total = 0
    position = haystack.find(needle)
    while position != -1:
        total += 1
        position = haystack.find(needle, position + 1)
    return total

def assertLogFileContainsExactly(needle, num):
    """Assert that ``needle`` occurs exactly ``num`` times in the debug log.

    Occurrences are counted with grepCount().  Raises AssertionError, with the
    expected and actual counts and the log contents attached, otherwise.
    """
    haystack = debug_logging.readLogFile()
    count = grepCount(haystack, needle)
    # Report the same snapshot that was counted; re-reading the file here
    # could show different contents if something logged in the meantime.
    assert (count == num), (
      'Expected exactly %u, found %u.  Logfile:\n%s' % (
        num, count, haystack
    ))

def assertLogFileContainsAtLeast(needle, min):
    """Assert that ``needle`` occurs at least ``min`` times in the debug log.

    Occurrences are counted with grepCount().  Raises AssertionError, with the
    expected and actual counts and the log contents attached, otherwise.

    The parameter name ``min`` shadows the builtin inside this function; it is
    kept because callers may pass it by keyword.
    """
    haystack = debug_logging.readLogFile()
    count = grepCount(haystack, needle)
    # Report the same snapshot that was counted; re-reading the file here
    # could show different contents if something logged in the meantime.
    assert (count >= min), (
      'Expected at least %u, found %u.  Logfile:\n%s' % (
        min, count, haystack
    ))

def assertLogFileContainsAtMost(needle, max):
    """Assert that ``needle`` occurs at most ``max`` times in the debug log.

    Occurrences are counted with grepCount().  Raises AssertionError, with the
    expected and actual counts and the log contents attached, otherwise.

    The parameter name ``max`` shadows the builtin inside this function; it is
    kept because callers may pass it by keyword.
    """
    haystack = debug_logging.readLogFile()
    count = grepCount(haystack, needle)
    # Report the same snapshot that was counted; re-reading the file here
    # could show different contents if something logged in the meantime.
    assert (count <= max), (
      'Expected at most %u, found %u.  Logfile:\n%s' % (
        max, count, haystack
    ))

class SclappTestCase(unittest.TestCase):
    """Base test case that removes the debug log file around every test."""

    def setUp(self):
        # Get rid of any log file left behind by an earlier test before
        # handing over to the base class.
        removeLogFile()
        base = super(SclappTestCase, self)
        return base.setUp()

    def tearDown(self):
        # Do not leave this test's log file behind for the next one.
        removeLogFile()
        base = super(SclappTestCase, self)
        return base.tearDown()

# The class must be created inside its own function scope.  __init__ refers to
# ``test`` through a closure, and a closure looks the name up when it is
# called, not when it is defined.  If the class were defined directly in the
# loop in getAllDocTestCasesFromModule, below, every class would share that
# loop's ``test`` variable, so the same (last) doctest would be used for all
# test cases.
def _makeDocTestCase(test):
    """Build a DocTestCase subclass that is permanently bound to ``test``.

    test -- a doctest.DocTest instance.

    The returned class ignores its constructor arguments and always runs
    ``test``, so it can be instantiated the way test loaders instantiate
    ordinary TestCase classes.  It is named ``<last component of
    test.name>_TestCase``.
    """
    case_name = '%s_TestCase' % test.name.rsplit('.', 1)[-1]

    class _SclappDocTestCase(doctest.DocTestCase):
        def __init__(self, *args, **kwargs):
            # Arguments (e.g. a method name supplied by a loader) are
            # deliberately discarded; only the bound doctest matters.
            doctest.DocTestCase.__init__(self, test)

    _SclappDocTestCase.__name__ = case_name
    return _SclappDocTestCase

def getAllDocTestCasesFromModule(name):
    """Return one DocTestCase subclass per doctest found in module ``name``.

    name -- module name, resolved with sclapp.util.importName.

    The doctests are collected with a default doctest.DocTestFinder and each
    one is wrapped by _makeDocTestCase().
    """
    from sclapp.util import importName

    module = importName(name)
    found = doctest.DocTestFinder().find(module)
    return [ _makeDocTestCase(doc_test) for doc_test in found ]

def defineDocTestCasesFromModule(dest_name, src_name = None):
    """Publish the doctests of one module as test case classes in another.

    dest_name -- name of the module that receives the generated classes.
    src_name  -- name of the module whose doctests are collected; defaults to
                 ``dest_name``.

    Each class produced by getAllDocTestCasesFromModule() is stored in the
    destination module's namespace under the class's own name, where
    name-based test discovery can find it.
    """
    if src_name is None:
        src_name = dest_name

    from sclapp.util import importName
    # Assumes importName returns the imported module object.
    dest_mod = importName(dest_name)
    for test_case in getAllDocTestCasesFromModule(src_name):
        # __module__ holds the *name* of the defining module (a string), not
        # the module object; code that formats it or looks it up in
        # sys.modules relies on that.
        test_case.__module__ = dest_mod.__name__
        dest_mod.__dict__[test_case.__name__] = test_case

def defineDocTestCasesFromTextFile(filename, dest_name):
    """Publish the doctest examples of a text file as a test case class.

    filename  -- path of the text file containing the examples.
    dest_name -- name of the module that receives the generated class; its
                 namespace also provides the globals the examples run with.

    The whole file is parsed into a single doctest, wrapped by
    _makeDocTestCase() and stored in the destination module's namespace under
    the class's own name.
    """
    from sclapp.util import importName
    dest_mod = importName(dest_name)

    f = open(filename, 'r')
    try:
        parser = doctest.DocTestParser()
        test = parser.get_doctest(
          f.read(),
          # A shallow copy, as doctest makes for its own file-based suites:
          # the examples execute in this dictionary and doctest's unittest
          # integration clears it when the test is torn down, which must not
          # happen to the live namespace of the destination module.
          globs = dest_mod.__dict__.copy(),
          name = filename,
          filename = filename,
          lineno = 0,
        )
    finally:
        f.close()

    # NOTE(review): the class name is derived from the text after the last
    # '.' in ``filename`` (e.g. 'txt_TestCase' for any *.txt file), so two
    # files with the same extension would overwrite each other in dest_mod --
    # confirm whether callers rely on the current names before changing this.
    test_case = _makeDocTestCase(test)
    dest_mod.__dict__[test_case.__name__] = test_case
