#!/usr/bin/env python
"""
Simple example script for running and testing notebooks.

Usage: `ipnbdoctest.py foo.ipynb [bar.ipynb [...]]`

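Each code cell is submitted to a fresh IPython kernel (one per notebook), and
its outputs are compared with those stored in the notebook, after normalizing
values that vary between runs (memory addresses, UUIDs, SVG geometry, ...).
Cells starting with `%timeit` are skipped.  The script exits with status 1 if
a notebook does not replicate, and with status 77 (test skipped) if a
prerequisite is missing or a cell raises `SystemExit(77)`.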
"""

from __future__ import print_function

import os,sys,time
import base64
import re
from difflib import unified_diff as diff

from collections import defaultdict
try:
    from queue import Empty
except ImportError:
    print('Python 3.x is needed to run this script.')
    sys.exit(77)

import imp
try:
    imp.find_module('IPython')
except:
    print('IPython is needed to run this script.')
    sys.exit(77)

try:
    from IPython.kernel import KernelManager
except ImportError:
    from IPython.zmq.blockingkernelmanager \
      import BlockingKernelManager as KernelManager

# Until Debian ships IPython 3.0, we stick to the v3 format.
from IPython.nbformat import v3 as nbformat

def compare_png(a64, b64):
    """Compare two base64-encoded PNG images (incomplete).

    Both payloads are base64-decoded.  Decoding accepts ``str`` as well as
    ``bytes`` and raises ``binascii.Error`` on incorrectly padded input.
    The decoded images are not actually compared, so the function always
    returns True.

    ``base64.b64decode`` is used instead of ``base64.decodestring``.  The
    latter only accepted bytes on Python 3 and was removed in Python 3.9.
    """
    for payload in (a64, b64):
        base64.b64decode(payload)
    return True

# Ordered (pattern, replacement, flags) rules applied by sanitize().
_SANITIZE_RULES = (
    # Memory addresses from default object reprs.
    (r'at 0x[a-f0-9]+', 'object', 0),
    # UUIDs.
    (r'[a-f0-9]{8}(\-[a-f0-9]{4}){3}\-[a-f0-9]{12}', 'U-U-I-D', 0),
    # Graphviz version banner.
    (r'Generated by graphviz version.*', 'VERSION', 0),
    # Verbose SpinS compiler output (may span several lines).
    (r'SpinS Promela Compiler.*Compiled C .* to .*pml.spins',
     'SpinS output', re.DOTALL),
    # Graphviz may position SVG nodes differently depending on how it was
    # built, so anything that looks like a coordinate or size is blanked.
    (r'<path[^/]* d="[^"]*"', '<path', 0),
    (r'points="[^"]*"', 'points=""', 0),
    (r'x="[0-9.-]+"', 'x=""', 0),
    (r'y="[0-9.-]+"', 'y=""', 0),
    (r'width="[0-9.]+pt"', 'width=""', 0),
    (r'height="[0-9.]+pt"', 'height=""', 0),
    (r'viewBox="[0-9 .-]*"', 'viewbox=""', 0),
    (r'transform="[^"]*"', 'transform=""', 0),
)


def sanitize(s):
    """Normalize a notebook output value before it is compared.

    Anything that is not a ``str`` is returned untouched.  For strings:

    * CRLF line endings become LF;
    * trailing newlines are removed (other trailing whitespace is kept);
    * every rule of ``_SANITIZE_RULES`` is applied in order.  The rules
      mask text that legitimately changes between runs or installations:
      memory addresses, UUIDs, Graphviz version banners, SpinS compiler
      chatter and SVG geometry.
    """
    if isinstance(s, str):
        s = s.replace('\r\n', '\n').rstrip('\n')
        for pattern, replacement, flags in _SANITIZE_RULES:
            s = re.sub(pattern, replacement, s, flags=flags)
    return s


def consolidate_outputs(outputs):
    """Consolidate a sequence of outputs into a summary dict (incomplete).

    The returned ``defaultdict(list)`` contains:

    * one entry per stream name with that stream's text concatenated.
      ``'stdout'`` and ``'stderr'`` are always present and default to ``''``;
    * ``'pyerr'``: ``{'ename': ..., 'evalue': ...}`` of the last ``pyerr``
      output, if any;
    * for every other output, each ``png``/``svg``/``latex``/``html``/
      ``javascript``/``text``/``jpeg`` value it carries, appended to a list
      stored under that key.

    Outputs are read through mapping access, so plain dicts and
    NotebookNode objects (a dict subclass) are both accepted.
    """
    data = defaultdict(list)
    data['stdout'] = ''
    data['stderr'] = ''

    for out in outputs:
        # v3 notebook outputs (and those built by run_cell) record their
        # kind under 'output_type'; there is no 'type' key.
        output_type = out.get('output_type')
        if output_type == 'stream':
            name = out['stream']
            # Use .get() so the list default_factory is bypassed: adding a
            # str to a default [] would split the text into characters.
            data[name] = data.get(name, '') + out['text']
        elif output_type == 'pyerr':
            data['pyerr'] = dict(ename=out['ename'], evalue=out['evalue'])
        else:
            for key in ('png', 'svg', 'latex', 'html',
                        'javascript', 'text', 'jpeg'):
                if key in out:
                    data[key].append(out[key])
    return data


def _diff_lines(value):
    """Return *value* as newline-terminated lines suitable for difflib.

    Non-string values, such as output ``metadata`` dicts, are pretty-printed
    first, because they can neither be split into lines nor have a newline
    appended.
    """
    import pprint
    text = value if isinstance(value, str) else pprint.pformat(value)
    # difflib output is only readable when every line ends with '\n'.
    if not text.endswith('\n'):
        text += '\n'
    return text.splitlines(True)


def compare_outputs(test, ref, skip_cmp=('png', 'traceback',
                                         'latex', 'prompt_number')):
    """Check that the output *test* reproduces the reference output *ref*.

    Every key of *ref* must be present in *test*.  Unless the key is listed
    in *skip_cmp*, both values must also be equal after going through
    sanitize().  Keys present only in *test* are ignored.

    On the first problem, a diagnostic is printed and False is returned.
    The diagnostic lists the key sets when a key is missing, and shows a
    unified diff when values mismatch.  Otherwise True is returned.
    """
    for key in ref:
        if key not in test:
            print("missing key: %s != %s" % (test.keys(), ref.keys()))
            return False
        if key in skip_cmp:
            continue
        exp = sanitize(ref[key])
        eff = sanitize(test[key])
        if exp != eff:
            print("mismatch %s:" % key)
            print(''.join(diff(_diff_lines(exp), _diff_lines(eff),
                               fromfile='expected', tofile='effective')))
            return False
    return True

def _wait_for_ready_backport(kc):
    """Block until the kernel behind *kc* answers a kernel_info request.

    This backports ``BlockingKernelClient.wait_for_ready`` from IPython 3
    for clients that lack it.  Each shell wait is capped at 30 seconds (a
    timeout propagates the client's exception).  Afterwards, pending IOPub
    messages are discarded until none has arrived for 0.2 seconds.
    """
    kc.kernel_info()

    # Skip unrelated shell messages until the kernel_info_reply shows up.
    reply_type = None
    while reply_type != 'kernel_info_reply':
        reply_type = kc.get_shell_msg(block=True, timeout=30)['msg_type']

    # Drain the IOPub channel so that later reads only see fresh messages.
    draining = True
    while draining:
        try:
            kc.get_iopub_msg(block=True, timeout=0.2)
        except Empty:
            draining = False

def _iopub_to_output(msg_type, content):
    """Build a v3 NotebookNode output from one IOPub message.

    *msg_type* must already use the v3 names (``pyout``/``pyerr``).  A
    ``pyerr`` for ``SystemExit`` with value ``'77'`` exits this script
    with status 77 so that the test is reported as skipped.
    """
    out = nbformat.NotebookNode(output_type=msg_type)
    if msg_type == 'stream':
        out.stream = content['name']
        # Newer kernels send the text under 'text', older ones under 'data'.
        out.text = content['text'] if 'text' in content else content['data']
    elif msg_type in ('display_data', 'pyout'):
        out['metadata'] = content['metadata']
        for mime, payload in content['data'].items():
            # 'image/svg+xml' -> 'svg', 'text/plain' -> 'text', 'image/png'
            # -> 'png', ...: the attribute names used by the v3 format.
            attr = (mime.split('/')[-1].lower()
                    .replace('+xml', '').replace('plain', 'text'))
            setattr(out, attr, payload)
        if 'execution_count' in content:
            out.prompt_number = content['execution_count']
    elif msg_type == 'pyerr':
        out.ename = content['ename']
        out.evalue = content['evalue']
        out.traceback = content['traceback']
        # sys.exit(77) in a cell means "skip this test".
        if out.ename == 'SystemExit' and out.evalue == '77':
            sys.exit(77)
    else:
        print("unhandled iopub msg:", msg_type)
    return out


def run_cell(kc, cell):
    """Execute *cell* through the kernel client *kc* and return its outputs.

    The cell input is sent for execution, and the shell reply is awaited
    for at most 20 seconds.  IOPub messages are then collected until none
    arrives for 0.2 seconds:

    * status and input echoes are ignored;
    * ``clear_output`` discards everything collected so far;
    * every other message becomes a v3 NotebookNode output.
    """
    kc.execute(cell.input)
    kc.get_shell_msg(timeout=20)

    # IPython 3 message names mapped back to their v3 notebook equivalents.
    v3_names = {'execute_result': 'pyout', 'error': 'pyerr'}
    collected = []
    while True:
        try:
            msg = kc.get_iopub_msg(timeout=0.2)
        except Empty:
            return collected
        kind = msg['msg_type']
        if kind in ('status', 'pyin', 'execute_input'):
            continue
        if kind == 'clear_output':
            collected = []
            continue
        kind = v3_names.get(kind, kind)
        collected.append(_iopub_to_output(kind, msg['content']))


def _cell_matches(outs, cell):
    """Return True if *outs* equals the outputs stored in *cell*.

    Diagnostics for every mismatch are printed.
    """
    matches = True
    if len(outs) != len(cell.outputs):
        print("output length mismatch (expected {}, got {})".format(
              len(cell.outputs), len(outs)))
        matches = False
    # Compare every pair, not only up to the first failure, so that all
    # diffs get printed.
    for out, ref in zip(outs, cell.outputs):
        if not compare_outputs(out, ref):
            matches = False
    return matches


def _run_cells(kc, nb):
    """Run the code cells of *nb* through *kc*, printing one line per cell.

    Return the ``(successes, failures, errors)`` counts.
    """
    successes = failures = errors = 0
    for ws in nb.worksheets:
        for i, cell in enumerate(ws.cells):
            # %timeit cells are skipped (presumably because their timings
            # differ from run to run).
            if cell.cell_type != 'code' or cell.input.startswith('%timeit'):
                continue
            try:
                outs = run_cell(kc, cell)
            except Exception as e:
                print("failed to run cell:", repr(e))
                print(cell.input)
                errors += 1
                continue

            matches = _cell_matches(outs, cell)
            print("cell %d: " % i, end="")
            if matches:
                print("OK")
                successes += 1
            else:
                print("FAIL")
                failures += 1
    return successes, failures, errors


def test_notebook(nb):
    """Run every code cell of the v3 notebook *nb* in a fresh kernel.

    A summary is printed, and the script exits with status 1 if any cell
    mismatched or failed to run.  The kernel channels and the kernel itself
    are always shut down, even when a cell asks for the test to be skipped
    (``sys.exit(77)`` in run_cell) or an unexpected error escapes.
    """
    km = KernelManager()
    # The devnull handle is closed once the kernel is gone, instead of
    # being leaked.
    with open(os.devnull, 'w') as devnull:
        # Do not save the history to disk, as it can yield spurious lock errors.
        # See https://github.com/ipython/ipython/issues/2845
        km.start_kernel(extra_arguments=['--HistoryManager.hist_file=:memory:'],
                        stderr=devnull)
        try:
            kc = km.client()
            kc.start_channels()
            try:
                try:
                    kc.wait_for_ready()
                except AttributeError:
                    # Clients from before IPython 3 lack wait_for_ready().
                    _wait_for_ready_backport(kc)
                successes, failures, errors = _run_cells(kc, nb)
            finally:
                kc.stop_channels()
        finally:
            km.shutdown_kernel()

    print()
    print("tested notebook %s" % nb.metadata.name)
    print("    %3i cells successfully replicated" % successes)
    if failures:
        print("    %3i cells mismatched output" % failures)
    if errors:
        print("    %3i cells failed to complete" % errors)
    if failures or errors:
        sys.exit(1)

if __name__ == '__main__':
    # Test each notebook given on the command line, in order.
    for ipynb in sys.argv[1:]:
        print("testing %s" % ipynb)
        # Notebook files are UTF-8 encoded JSON; do not rely on the
        # locale's default encoding.
        with open(ipynb, encoding='utf-8') as f:
            nb = nbformat.reads_json(f.read())
        test_notebook(nb)
