ipnbdoctest.py 8 KB
Newer Older
1 2 3 4 5 6
#!/usr/bin/env python
"""
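Simple example script for running and testing notebooks.

Usage: `ipnbdoctest.py foo.ipynb [bar.ipynb [...]]`

Each code cell is submitted to an IPython kernel, and the outputs are
compared with those stored in the notebook.  The script exits with
status 77 (skip) when Python 3 or IPython is unavailable or when a cell
raises SystemExit(77), and with status 1 when a notebook has cells whose
output mismatched or that failed to run.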
"""

from __future__ import print_function

import os,sys,time
import base64
import re
from difflib import unified_diff as diff

from collections import defaultdict
try:
    from queue import Empty
except ImportError:
    print('Python 3.x is needed to run this script.')
    sys.exit(77)

25 26 27 28 29 30 31
import imp
try:
    imp.find_module('IPython')
except:
    print('IPython is needed to run this script.')
    sys.exit(77)

32 33 34
try:
    from IPython.kernel import KernelManager
except ImportError:
35 36
    from IPython.zmq.blockingkernelmanager \
      import BlockingKernelManager as KernelManager
37

38 39
# Until Debian ships IPython 3.0, we stick to the v3 format.
from IPython.nbformat import v3 as nbformat
40 41 42 43 44 45 46 47 48 49 50 51 52

def compare_png(a64, b64):
    """Compare two base64-encoded PNG images (incomplete).

    Both arguments (``str`` or ``bytes``) are base64-decoded, so malformed
    input such as bad padding raises ``binascii.Error``.  No image
    comparison is performed yet: any two decodable inputs are reported
    as equal.

    Returns True.
    """
    # base64.decodestring() was removed in Python 3.9 (and never accepted
    # ``str`` on Python 3); b64decode() accepts both ``str`` and ``bytes``.
    # Decoding only validates the input; the bytes are not compared yet.
    base64.b64decode(a64)
    base64.b64decode(b64)
    return True

def sanitize(s):
    """Normalize a value so that two notebook outputs compare reliably.

    Anything that is not a ``str`` is handed back unchanged.  Strings get
    ``\\r\\n`` converted to ``\\n`` and trailing newlines removed (trailing
    spaces are preserved).  Then values that legitimately vary between
    runs are masked: hex object addresses, UUIDs, the graphviz version
    banner, and SVG geometry attributes (path data, points, x/y
    coordinates, pt sizes, viewBox, transforms), because graphviz builds
    differ in where they place nodes.
    """
    if not isinstance(s, str):
        return s

    # Applied in this exact order; later rules see earlier results.
    rules = (
        # memory addresses such as "<Foo at 0x7f3a...>"
        (r'at 0x[a-f0-9]+', 'object'),
        # UUIDs
        (r'[a-f0-9]{8}(\-[a-f0-9]{4}){3}\-[a-f0-9]{12}', 'U-U-I-D'),
        # graphviz version banner
        (r'Generated by graphviz version.*', 'VERSION'),
        # SVG positions and sizes that depend on the graphviz build
        (r'<path[^/]* d="[^"]*"', '<path'),
        (r'points="[^"]*"', 'points=""'),
        (r'x="[0-9.-]+"', 'x=""'),
        (r'y="[0-9.-]+"', 'y=""'),
        (r'width="[0-9.]+pt"', 'width=""'),
        (r'height="[0-9.]+pt"', 'height=""'),
        (r'viewBox="[0-9 .-]*"', 'viewbox=""'),
        (r'transform="[^"]*"', 'transform=""'),
    )

    text = s.replace('\r\n', '\n').rstrip('\n')
    for pattern, replacement in rules:
        text = re.sub(pattern, replacement, text)
    return text


def consolidate_outputs(outputs):
    """Consolidate a list of cell outputs into a summary dict (incomplete).

    Returns a ``defaultdict(list)`` in which:

    * ``'stdout'`` and ``'stderr'`` (always present) and any other stream
      name hold the concatenated text of the matching stream outputs;
    * ``'pyerr'`` holds ``{'ename': ..., 'evalue': ...}`` of the last error
      output, if any;
    * each rich-display key (``'png'``, ``'svg'``, ``'latex'``, ``'html'``,
      ``'javascript'``, ``'text'``, ``'jpeg'``) holds the list of values
      found in the remaining outputs, in order.
    """
    data = defaultdict(list)
    data['stdout'] = ''
    data['stderr'] = ''

    for out in outputs:
        # nbformat output nodes store their kind in ``output_type``;
        # there is no ``type`` attribute.
        kind = out.output_type
        if kind == 'stream':
            # .get() bypasses the defaultdict's list default, which would
            # otherwise be extended character by character for an
            # unexpected stream name.
            data[out.stream] = data.get(out.stream, '') + out.text
        elif kind == 'pyerr':
            data['pyerr'] = dict(ename=out.ename, evalue=out.evalue)
        else:
            for key in ('png', 'svg', 'latex', 'html',
                        'javascript', 'text', 'jpeg'):
                if key in out:
                    data[key].append(out[key])
    return data


107 108
def compare_outputs(test, ref, skip_cmp=('png', 'traceback',
                                         'latex', 'prompt_number')):
    """Return True if output node *test* matches the reference *ref*.

    Every key of *ref* must be present in *test*.  Keys listed in
    *skip_cmp* only need to be present; all other values must be equal
    after :func:`sanitize`.  Keys that only *test* has are ignored.

    On the first difference a diagnostic is printed (a unified diff for
    string values, the two reprs otherwise) and False is returned.
    """
    for key in ref:
        if key not in test:
            print("missing key: %s != %s" % (test.keys(), ref.keys()))
            return False
        if key in skip_cmp:
            continue
        exp = sanitize(ref[key])
        eff = sanitize(test[key])
        if exp == eff:
            continue

        print("mismatch %s:" % key)
        if isinstance(exp, str) and isinstance(eff, str):
            # Terminate the last line of each side; otherwise it runs into
            # the next diff line when the output is joined.
            if not exp.endswith('\n'):
                exp += '\n'
            if not eff.endswith('\n'):
                eff += '\n'
            print(''.join(diff(exp.splitlines(True), eff.splitlines(True),
                               fromfile='expected', tofile='effective')))
        else:
            # sanitize() returns non-strings (e.g. metadata dicts) unchanged;
            # they cannot be diffed line by line.
            print("expected:  %r" % (exp,))
            print("effective: %r" % (eff,))
        return False
    return True

127
def _wait_for_ready_backport(kc):
    """Wait until the kernel behind *kc* answers, then empty IOPub.

    Stand-in for ``BlockingKernelClient.wait_for_ready`` (IPython 3) on
    clients that lack it.  A ``kernel_info`` request is sent, and shell
    messages are consumed (each call waiting up to 30s) until the
    ``kernel_info_reply`` arrives.  Afterwards, IOPub messages are
    discarded until none arrives within 0.2s.
    """
    kc.kernel_info()

    got_reply = False
    while not got_reply:
        reply = kc.get_shell_msg(block=True, timeout=30)
        got_reply = reply['msg_type'] == 'kernel_info_reply'

    draining = True
    while draining:
        try:
            kc.get_iopub_msg(block=True, timeout=0.2)
        except Empty:
            draining = False
141

142
def run_cell(kc, cell):
    """Execute *cell* on kernel client *kc* and collect its outputs.

    The cell source is sent with ``kc.execute``.  After the shell reply
    arrives (waiting up to 20s), IOPub messages are read until none
    arrives within 0.2s.  Each output message is converted into an
    nbformat-v3 output node; the ``execute_result``/``error`` messages are
    mapped to the v3 names ``pyout``/``pyerr``.  A ``clear_output`` message
    discards the outputs gathered so far.  Message types that are not
    cell outputs are reported and ignored.

    If the cell raised ``SystemExit(77)``, the process exits with status
    77 (test skipped).

    Returns the list of output nodes.
    """
    kc.execute(cell.input)
    # wait for the execution to finish (at most 20s)
    kc.get_shell_msg(timeout=20)
    outs = []

    while True:
        try:
            msg = kc.get_iopub_msg(timeout=0.2)
        except Empty:
            break
        msg_type = msg['msg_type']
        if msg_type in ('status', 'pyin', 'execute_input'):
            continue
        elif msg_type == 'clear_output':
            outs = []
            continue

        content = msg['content']
        # Newer message names are mapped back to their nbformat-v3 names.
        if msg_type == 'execute_result':
            msg_type = 'pyout'
        elif msg_type == 'error':
            msg_type = 'pyerr'
        out = nbformat.NotebookNode(output_type=msg_type)

        if msg_type == 'stream':
            out.stream = content['name']
            # the stream text is under 'text', or 'data' in older messages
            if 'text' in content:
                out.text = content['text']
            else:
                out.text = content['data']
        elif msg_type in ('display_data', 'pyout'):
            out['metadata'] = content['metadata']
            for mime, data in content['data'].items():
                attr = mime.split('/')[-1].lower()
                # this gets most right, but fix svg+xml, plain
                attr = attr.replace('+xml', '').replace('plain', 'text')
                setattr(out, attr, data)
            if 'execution_count' in content:
                out.prompt_number = content['execution_count']
        elif msg_type == 'pyerr':
            out.ename = content['ename']
            out.evalue = content['evalue']
            out.traceback = content['traceback']

            # sys.exit(77) is used to Skip the test.
            if out.ename == 'SystemExit' and out.evalue == '77':
                sys.exit(77)
        else:
            # Not a cell output (presumably e.g. comm_* messages).  Recording
            # it would add a bogus output and always cause a length mismatch.
            print("unhandled iopub msg:", msg_type)
            continue

        outs.append(out)
    return outs
197

198 199 200

def _run_cells(kc, nb):
    """Wait for the kernel, run and check each code cell; return (successes, failures, errors)."""
    try:
        kc.wait_for_ready()
    except AttributeError:
        # Clients without wait_for_ready() use the backport.
        _wait_for_ready_backport(kc)

    successes = 0
    failures = 0
    errors = 0
    for ws in nb.worksheets:
        for i, cell in enumerate(ws.cells):
            # Only code cells are run; %timeit output is not reproducible.
            if cell.cell_type != 'code' or cell.input.startswith('%timeit'):
                continue
            try:
                outs = run_cell(kc, cell)
            except Exception as e:
                print("failed to run cell:", repr(e))
                print(cell.input)
                errors += 1
                continue

            failed = False
            if len(outs) != len(cell.outputs):
                print("output length mismatch (expected {}, got {})".format(
                      len(cell.outputs), len(outs)))
                failed = True
            for out, ref in zip(outs, cell.outputs):
                if not compare_outputs(out, ref):
                    failed = True
            print("cell %d: " % i, end="")
            if failed:
                print("FAIL")
                failures += 1
            else:
                print("OK")
                successes += 1
    return successes, failures, errors


def test_notebook(nb):
    """Run the code cells of notebook *nb* and compare with stored outputs.

    A fresh kernel is started for the notebook.  Non-code cells and cells
    starting with ``%timeit`` are skipped.  Every other cell is executed
    with :func:`run_cell`, and its outputs are compared to the stored ones
    with :func:`compare_outputs`.  A summary is printed, and the process
    exits with status 1 if any cell mismatched or failed to run.

    The kernel is shut down even when the run is aborted, e.g. by the
    ``sys.exit(77)`` skip in :func:`run_cell`.
    """
    km = KernelManager()
    # The kernel's stderr is discarded.  The handle stays open while the
    # kernel runs and is closed once the kernel has been shut down.
    with open(os.devnull, 'w') as devnull:
        # Do not save the history to disk, as it can yield spurious lock
        # errors.  See https://github.com/ipython/ipython/issues/2845
        km.start_kernel(extra_arguments=['--HistoryManager.hist_file=:memory:'],
                        stderr=devnull)
        try:
            kc = km.client()
            kc.start_channels()
            try:
                successes, failures, errors = _run_cells(kc, nb)
            finally:
                kc.stop_channels()
        finally:
            km.shutdown_kernel()

    print()
    print("tested notebook %s" % nb.metadata.name)
    print("    %3i cells successfully replicated" % successes)
    if failures:
        print("    %3i cells mismatched output" % failures)
    if errors:
        print("    %3i cells failed to complete" % errors)
    if failures or errors:
        sys.exit(1)

if __name__ == '__main__':
    # Test each notebook given on the command line, in order.  Note that
    # test_notebook() exits the process on the first failing notebook.
    for ipynb in sys.argv[1:]:
        print("testing %s" % ipynb)
        # .ipynb files are UTF-8 encoded JSON; do not rely on the locale's
        # default encoding.
        with open(ipynb, encoding='utf-8') as f:
            nb = nbformat.reads_json(f.read())
        test_notebook(nb)