###############################################################################
#                                                                             #
# Copyright (C) 2008-2014 Edward d'Auvergne                                   #
#                                                                             #
# This file is part of the program relax (http://www.nmr-relax.com).          #
#                                                                             #
# This program is free software: you can redistribute it and/or modify        #
# it under the terms of the GNU General Public License as published by        #
# the Free Software Foundation, either version 3 of the License, or           #
# (at your option) any later version.                                         #
#                                                                             #
# This program is distributed in the hope that it will be useful,             #
# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
# GNU General Public License for more details.                                #
#                                                                             #
# You should have received a copy of the GNU General Public License           #
# along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
#                                                                             #
###############################################################################

# Python module imports.
import dep_check
import sys
from time import time
from unittest import TextTestRunner
if dep_check.wx_module:
    import wx

# relax module imports.
from lib.compat import StringIO
from lib.compat import TextTestResult
from status import Status; status = Status()


class RelaxTestResult(TextTestResult):
    """A replacement for the TextTestResult class.

    This class is designed to catch STDOUT and STDERR during the execution of each test and to
    prepend the output to the failure and error reports normally generated by TextTestRunner.
    """

    def __init__(self, stream, descriptions, verbosity, timing=False, category=None):
        """Initialise the RelaxTestResult object with relax specific variables.

        @param stream:          The stream to which the test report is written (passed to TextTestResult).
        @type stream:           file-like object
        @param descriptions:    The descriptions flag (passed to TextTestResult).
        @type descriptions:     bool
        @param verbosity:       The verbosity level (passed to TextTestResult).
        @type verbosity:        int
        @keyword timing:        A flag which if True will enable timing of individual tests.
        @type timing:           bool
        @keyword category:      The type of test being performed, to allow the printouts to be changed.  This can be one of 'system', 'unit', 'gui', or 'verification'.
        @type category:         str
        """

        # Normal setup.
        super(RelaxTestResult, self).__init__(stream, descriptions, verbosity)

        # Store the timing flag and category.
        self.timing_flag = timing
        self.category = category

        # The capture buffer and start time are normally created in startTest().  Initialise them
        # here so that a report arriving without a preceding startTest() call (for example a
        # setUpClass() failure) does not raise an AttributeError.
        self.capt = None
        self.time = None


    def _prepend_capture(self, reports):
        """Prepend the captured STDOUT and STDERR text to the last report of the given list.

        @param reports: The list of (test, text) tuples to modify, i.e. self.errors or self.failures.
        @type reports:  list of tuple
        """

        # Nothing has been captured yet, or there is no report to modify.
        if self.capt is None or not reports:
            return

        # Replace the last tuple, as tuples are immutable.
        test, text = reports[-1]
        reports[-1] = (test, self.capt.getvalue() + text)


    def addError(self, test, err):
        """Override of the TestResult.addError() method.

        The STDOUT and STDERR captured text is prepended to the error text here.


        @param test:    The test object.
        @type test:     TestCase instance
        @param err:     A tuple of values as returned by sys.exc_info().
        @type err:      tuple of values
        """

        # Execute the base class method to print the 'E' and handle the error.
        super(RelaxTestResult, self).addError(test, err)

        # Prepend the STDOUT and STDERR messages to the second element of the tuple.
        self._prepend_capture(self.errors)

        # Write out timing info.
        if self.timing_flag:
            self.write_time(test.id())


    def addFailure(self, test, err):
        """Override of the TestResult.addFailure() method.

        The STDOUT and STDERR captured text is prepended to the failure text here.


        @param test:    The test object.
        @type test:     TestCase instance
        @param err:     A tuple of values as returned by sys.exc_info().
        @type err:      tuple of values
        """

        # Execute the base class method to print the 'F' and handle the failure.
        super(RelaxTestResult, self).addFailure(test, err)

        # Prepend the STDOUT and STDERR messages to the second element of the tuple.
        self._prepend_capture(self.failures)

        # Write out timing info.
        if self.timing_flag:
            self.write_time(test.id())


    def addSuccess(self, test):
        """The method for a successful test.

        @param test:    The test object.
        @type test:     TestCase instance
        """

        # Execute the base class method to print the '.'.
        super(RelaxTestResult, self).addSuccess(test)

        # Write out timing info.
        if self.timing_flag:
            self.write_time(test.id())


    def startTest(self, test):
        """Override of the TextTestResult.startTest() method.

        The start of STDOUT and STDERR capture occurs here.  In debug mode the streams are not
        redirected, so the capture buffer stays empty.

        @param test:    The test object.
        @type test:     TestCase instance
        """

        # Store the original STDOUT and STDERR for restoring later on.
        self.orig_stdout = sys.stdout
        self.orig_stderr = sys.stderr

        # Catch stdout and stderr.
        self.capt = StringIO()
        if not status.debug:
            sys.stdout = self.capt
            sys.stderr = self.capt

        # Place the test name in the status object.
        status.exec_lock.test_name = str(test)

        # Store the starting time.
        if self.timing_flag:
            self.time = time()

        # Execute the normal startTest method.
        super(RelaxTestResult, self).startTest(test)


    def stopTest(self, test):
        """Override of the TextTestResult.stopTest() method.

        The end of STDOUT and STDERR capture occurs here.

        @param test:    The test object.
        @type test:     TestCase instance
        """

        # Execute the base class method first.  When output buffering is active, it restores its
        # own redirection, which must happen before the relax streams are put back.
        super(RelaxTestResult, self).stopTest(test)

        # Restore the IO streams.
        sys.stdout = self.orig_stdout
        sys.stderr = self.orig_stderr


    def write_time(self, test_name):
        """Write the timing of the test to the stream.

        For all categories other than 'unit', only the last two dot-separated components of the
        test name are printed.

        @param test_name:   The TestCase name.
        @type test_name:    str
        """

        # No start time has been recorded, so there is nothing to report.
        if self.time is None:
            return

        # The elapsed time.  The start time is deliberately not modified, so a test reported more
        # than once (e.g. a failure followed by a tearDown() error) still prints a sensible value.
        elapsed = time() - self.time

        # Shorten the test name, keeping the whole name if it has fewer than two components.
        if self.category != 'unit':
            test_name = '.'.join(test_name.split('.')[-2:])

        # The printout.
        self.stream.write('  %7.2f s for %s\n' % (elapsed, test_name))


class GuiTestResult(RelaxTestResult):
    """A replacement for the TextTestResult class for the GUI.

    This extends RelaxTestResult so that the wx application is handed control once each test has
    finished.
    """

    def stopTest(self, test):
        """Override of the RelaxTestResult.stopTest() method.

        The parent method ends the STDOUT and STDERR capture, after which the GUI is allowed to update.

        @param test:    The test object.
        @type test:     TestCase instance
        """

        # Let the parent class restore the IO streams.
        super(GuiTestResult, self).stopTest(test)

        # Hand control to the wx application so that the GUI can be updated.
        app = wx.GetApp()
        app.Yield(True)



class RelaxTestRunner(TextTestRunner):
    """A replacement unittest runner.

    This runner is designed to catch STDOUT during the execution of each test and to prepend the
    output to the failure and error reports normally generated by TextTestRunner.
    """

    # Variable for specifying the type of test being performed, to change the printout.
    category = None

    def __init__(self, stream=None, descriptions=True, verbosity=1, failfast=False, buffer=False, resultclass=None, timing=False):
        """Initialise the class, storing the timing flag.

        @keyword stream:        The stream to write the test report to.  If None, the current sys.stderr is used (looked up at instantiation rather than at import time).
        @type stream:           file-like object or None
        @keyword descriptions:  The descriptions flag (passed to TextTestRunner).
        @type descriptions:     bool
        @keyword verbosity:     The verbosity level (passed to TextTestRunner).
        @type verbosity:        int
        @keyword failfast:      Passed to TextTestRunner, but ignored on Python versions whose TextTestRunner does not accept it.
        @type failfast:         bool
        @keyword buffer:        Passed to TextTestRunner, but ignored on Python versions whose TextTestRunner does not accept it.
        @type buffer:           bool
        @keyword resultclass:   Passed to TextTestRunner, but ignored on Python versions whose TextTestRunner does not accept it.  The result object is always created by _makeResult().
        @type resultclass:      class or None
        @keyword timing:        A flag which if True will enable timing of individual tests.
        @type timing:           bool
        """

        # Resolve the default stream now, so that any later replacement of sys.stderr is honoured.
        if stream is None:
            stream = sys.stderr

        # Execute the base method.  Python versions before 2.7, and 3.0 to 3.1, do not accept the
        # failfast, buffer and resultclass arguments.
        if sys.version_info < (2, 7) or (3, 0) <= sys.version_info < (3, 2):
            super(RelaxTestRunner, self).__init__(stream=stream, descriptions=descriptions, verbosity=verbosity)
        else:
            super(RelaxTestRunner, self).__init__(stream=stream, descriptions=descriptions, verbosity=verbosity, failfast=failfast, buffer=buffer, resultclass=resultclass)

        # Store the flag.
        self.timing_flag = timing


    def _makeResult(self):
        """Override of the TextTestRunner._makeResult() method.

        @return:    The relax test result object.
        @rtype:     RelaxTestResult instance
        """

        return RelaxTestResult(self.stream, self.descriptions, self.verbosity, timing=self.timing_flag, category=self.category)



class GuiTestRunner(TextTestRunner):
    """A replacement unittest runner.

    This runner is designed to catch STDOUT during the execution of each test and to prepend the
    output to the failure and error reports normally generated by TextTestRunner.
    """

    # Variable for specifying the type of test being performed, to change the printout.
    category = None

    def __init__(self, stream=None, descriptions=True, verbosity=1, failfast=False, buffer=False, resultclass=None, timing=False):
        """Initialise the class, storing the timing flag.

        @keyword stream:        The stream to write the test report to.  If None, the current sys.stderr is used (looked up at instantiation rather than at import time).
        @type stream:           file-like object or None
        @keyword descriptions:  The descriptions flag (passed to TextTestRunner).
        @type descriptions:     bool
        @keyword verbosity:     The verbosity level (passed to TextTestRunner).
        @type verbosity:        int
        @keyword failfast:      Passed to TextTestRunner, but ignored on Python versions whose TextTestRunner does not accept it.
        @type failfast:         bool
        @keyword buffer:        Passed to TextTestRunner, but ignored on Python versions whose TextTestRunner does not accept it.
        @type buffer:           bool
        @keyword resultclass:   Passed to TextTestRunner, but ignored on Python versions whose TextTestRunner does not accept it.  The result object is always created by _makeResult().
        @type resultclass:      class or None
        @keyword timing:        A flag which if True will enable timing of individual tests.
        @type timing:           bool
        """

        # Resolve the default stream now, so that any later replacement of sys.stderr is honoured.
        if stream is None:
            stream = sys.stderr

        # Execute the base method.  Python versions before 2.7, and 3.0 to 3.1, do not accept the
        # failfast, buffer and resultclass arguments.
        if sys.version_info < (2, 7) or (3, 0) <= sys.version_info < (3, 2):
            super(GuiTestRunner, self).__init__(stream=stream, descriptions=descriptions, verbosity=verbosity)
        else:
            super(GuiTestRunner, self).__init__(stream=stream, descriptions=descriptions, verbosity=verbosity, failfast=failfast, buffer=buffer, resultclass=resultclass)

        # Store the flag.
        self.timing_flag = timing


    def _makeResult(self):
        """Override of the TextTestRunner._makeResult() method.

        @return:    The GUI test result object.
        @rtype:     GuiTestResult instance
        """

        return GuiTestResult(self.stream, self.descriptions, self.verbosity, timing=self.timing_flag, category=self.category)
