#!/usr/bin/env python
"""
Simple example script for running and testing notebooks.

Usage: `ipnbdoctest.py foo.ipynb [bar.ipynb [...]]`

Each cell is submitted to the kernel, and the outputs are compared
with those stored in the notebook.
"""

from __future__ import print_function

import base64
import imp
import os
import re
import sys
import time
from collections import defaultdict
from difflib import unified_diff as diff

# The ``queue`` module only exists under that name on Python 3; bail out
# with status 77 otherwise.
# NOTE(review): 77 is presumably the "test skipped" status expected by the
# calling test harness -- confirm.
try:
    from queue import Empty
except ImportError:
    print('Python 3.x is needed to run this script.')
    sys.exit(77)

# Probe for IPython without importing it; find_module raises ImportError
# when the package cannot be located.
try:
    imp.find_module('IPython')
except ImportError:
    print('IPython is needed to run this script.')
    sys.exit(77)

# Fall back to the older module layout when IPython.kernel is unavailable.
try:
    from IPython.kernel import KernelManager
except ImportError:
    from IPython.zmq.blockingkernelmanager import BlockingKernelManager as KernelManager

# Until Debian ships IPython 3.0, we stick to the v3 format.
from IPython.nbformat import v3 as nbformat

def compare_png(a64, b64):
    """Compare two base64-encoded PNGs (incomplete).

    Both arguments are decoded, so malformed base64 input raises
    ``binascii.Error``, but no pixel comparison is implemented yet: any
    pair of decodable inputs is reported as equal.

    a64, b64 -- base64 payloads, given as ``str`` or ``bytes``.

    Returns True.
    """
    # base64.decodestring() was removed in Python 3.9 and never accepted
    # str; b64decode() handles both str and bytes.
    base64.b64decode(a64)
    base64.b64decode(b64)
    return True

def sanitize(s):
    """Sanitize a string for comparison.

    Fix universal newlines, strip trailing newlines, and normalize likely
    random values (memory addresses and UUIDs) as well as position-like
    SVG attributes.

    Anything that is not a ``str`` is returned unchanged.
    """
    if not isinstance(s, str):
        return s

    # normalize newline:
    s = s.replace('\r\n', '\n')

    # ignore trailing newlines (but not space)
    s = s.rstrip('\n')

    # remove hex addresses:
    s = re.sub(r'at 0x[a-f0-9]+', 'object', s)

    # normalize UUIDs:
    s = re.sub(r'[a-f0-9]{8}(\-[a-f0-9]{4}){3}\-[a-f0-9]{12}', 'U-U-I-D', s)

    # SVG generated by graphviz may put nodes at different positions
    # depending on the graphviz build.  Let's just strip anything that
    # looks like a position.
    s = re.sub(r'<path[^/]* d="[^"]*"', '<path', s)
    s = re.sub(r'points="[^"]*"', 'points=""', s)
    s = re.sub(r'x="[0-9.-]+"', 'x=""', s)
    s = re.sub(r'y="[0-9.-]+"', 'y=""', s)
    s = re.sub(r'width="[0-9.]+pt"', 'width=""', s)
    s = re.sub(r'height="[0-9.]+pt"', 'height=""', s)
    s = re.sub(r'viewBox="[0-9 .-]*"', 'viewbox=""', s)
    s = re.sub(r'transform="[^"]*"', 'transform=""', s)
    return s


def consolidate_outputs(outputs):
    """Consolidate outputs into a summary dict (incomplete).

    outputs -- iterable of dict-like output nodes carrying an
               'output_type' entry, as built by run_cell().

    Returns a ``defaultdict(list)`` in which:
      * each stream name ('stdout' and 'stderr' are always present) maps
        to the concatenated text of that stream;
      * 'pyerr' maps to a dict with the 'ename' and 'evalue' of the last
        error output, if any;
      * each of 'png', 'svg', 'latex', 'html', 'javascript', 'text' and
        'jpeg' maps to the list of payloads found under that key.
    """
    data = defaultdict(list)
    data['stdout'] = ''
    data['stderr'] = ''

    for out in outputs:
        # run_cell() stores the kind of output under 'output_type'.
        kind = out['output_type']
        if kind == 'stream':
            name = out['stream']
            # .get() bypasses the list default so that an unexpected
            # stream name still accumulates text rather than characters.
            data[name] = data.get(name, '') + out['text']
        elif kind == 'pyerr':
            data['pyerr'] = dict(ename=out['ename'], evalue=out['evalue'])
        else:
            for key in ('png', 'svg', 'latex', 'html',
                        'javascript', 'text', 'jpeg',):
                if key in out:
                    data[key].append(out[key])
    return data


def _as_diff_lines(value):
    """Return *value* as newline-terminated lines suitable for difflib."""
    text = value if isinstance(value, str) else repr(value)
    if not text.endswith('\n'):
        text += '\n'
    return text.splitlines(True)


def compare_outputs(test, ref, skip_cmp=('png', 'traceback',
                                         'latex', 'prompt_number')):
    """Compare one produced output against its stored reference.

    test     -- mapping describing the output obtained from the kernel.
    ref      -- mapping describing the output stored in the notebook.
    skip_cmp -- keys that must exist in *test* but whose values are not
                compared.

    Every key of *ref* must be present in *test*; the remaining values
    are passed through sanitize() before being compared.  On the first
    missing key or mismatch a diagnostic (a unified diff for mismatches)
    is printed and False is returned.  Returns True otherwise.
    """
    for key in ref:
        if key not in test:
            print("missing key: %s != %s" % (test.keys(), ref.keys()))
            return False
        if key in skip_cmp:
            continue
        exp = sanitize(ref[key])
        eff = sanitize(test[key])
        if exp != eff:
            print("mismatch %s:" % key)
            # Values are not always strings (e.g. a metadata dict), so
            # they are rendered as text before being diffed.
            print(''.join(diff(_as_diff_lines(exp), _as_diff_lines(eff),
                               fromfile='expected', tofile='effective')))
            return False
    return True

def _wait_for_ready_backport(kc):
    """Backport BlockingKernelClient.wait_for_ready from IPython 3.

    kc -- kernel client whose channels have already been started.

    Blocks until a 'kernel_info_reply' arrives on the shell channel, then
    discards whatever is pending on the IOPub channel.  ``Empty``
    propagates if no shell message arrives within 30 seconds.
    """
    # Wait for kernel info reply on shell channel
    kc.kernel_info()
    while True:
        msg = kc.get_shell_msg(block=True, timeout=30)
        if msg['msg_type'] == 'kernel_info_reply':
            break

    # Flush IOPub channel: read until it stays quiet for 0.2 seconds.
    while True:
        try:
            kc.get_iopub_msg(block=True, timeout=0.2)
        except Empty:
            break

def run_cell(kc, cell, timeout=20):
    """Execute one notebook cell and collect the outputs it produces.

    kc      -- kernel client used to submit the code and read messages.
    cell    -- notebook cell; its ``input`` attribute is executed.
    timeout -- maximum number of seconds to wait for the shell reply
               that signals the end of the execution.

    Returns the list of output nodes rebuilt from the IOPub messages.
    ``Empty`` propagates if no shell reply arrives within *timeout*.
    """
    kc.execute(cell.input)
    # wait for finish
    kc.get_shell_msg(timeout=timeout)
    outs = []

    while True:
        try:
            msg = kc.get_iopub_msg(timeout=0.2)
        except Empty:
            # Nothing arrived for 0.2 seconds: consider the output complete.
            break

        msg_type = msg['msg_type']
        if msg_type in ('status', 'pyin', 'execute_input'):
            continue
        elif msg_type == 'clear_output':
            outs = []
            continue

        content = msg['content']
        # Newer message names are mapped back to the names this script
        # stores in the output nodes.
        if msg_type == 'execute_result':
            msg_type = 'pyout'
        elif msg_type == 'error':
            msg_type = 'pyerr'
        out = nbformat.NotebookNode(output_type=msg_type)

        if msg_type == 'stream':
            out.stream = content['name']
            # The text is found under 'text' or 'data' depending on the
            # message.
            if 'text' in content:
                out.text = content['text']
            else:
                out.text = content['data']
        elif msg_type in ('display_data', 'pyout'):
            out['metadata'] = content['metadata']
            for mime, data in content['data'].items():
                attr = mime.split('/')[-1].lower()
                # this gets most right, but fix svg+xml, plain
                attr = attr.replace('+xml', '').replace('plain', 'text')
                setattr(out, attr, data)
            if 'execution_count' in content:
                out.prompt_number = content['execution_count']
        elif msg_type == 'pyerr':
            out.ename = content['ename']
            out.evalue = content['evalue']
            out.traceback = content['traceback']
        else:
            print("unhandled iopub msg:", msg_type)

        outs.append(out)
    return outs

def _run_and_compare_cells(kc, nb):
    """Run every code cell of *nb* on *kc* and compare with stored outputs.

    Returns a ``(successes, failures, errors)`` tuple counting the cells
    whose outputs matched, the cells whose outputs differed, and the
    cells that raised while being run.
    """
    successes = 0
    failures = 0
    errors = 0
    for ws in nb.worksheets:
        for i, cell in enumerate(ws.cells):
            if cell.cell_type != 'code':
                continue
            try:
                outs = run_cell(kc, cell)
            except Exception as e:
                # A cell that cannot be run is counted and skipped so
                # that the remaining cells are still checked.
                print("failed to run cell:", repr(e))
                print(cell.input)
                errors += 1
                continue

            failed = False
            if len(outs) != len(cell.outputs):
                print("output length mismatch (expected {}, got {})".format(
                      len(cell.outputs), len(outs)))
                failed = True
            for out, ref in zip(outs, cell.outputs):
                if not compare_outputs(out, ref):
                    failed = True
            print("cell %d: " % i, end="")
            if failed:
                print("FAIL")
                failures += 1
            else:
                print("OK")
                successes += 1
    return successes, failures, errors


def test_notebook(nb):
    """Replay notebook *nb* in a fresh kernel and check its outputs.

    A summary is printed once all cells have been processed.  The
    process exits with status 1 if any cell mismatched or failed to run.
    The kernel, its channels and the stderr sink are released even when
    an exception escapes.
    """
    km = KernelManager()
    with open(os.devnull, 'w') as devnull:
        # run %pylab inline, because some notebooks assume this
        # even though they shouldn't
        km.start_kernel(extra_arguments=['--pylab=inline'], stderr=devnull)
        try:
            kc = km.client()
            kc.start_channels()
            try:
                # Clients without wait_for_ready() use the backport.
                wait_for_ready = getattr(kc, 'wait_for_ready', None)
                if wait_for_ready is None:
                    _wait_for_ready_backport(kc)
                else:
                    wait_for_ready()
                successes, failures, errors = _run_and_compare_cells(kc, nb)
            finally:
                kc.stop_channels()
        finally:
            km.shutdown_kernel()

    print()
    print("tested notebook %s" % nb.metadata.name)
    print("    %3i cells successfully replicated" % successes)
    if failures:
        print("    %3i cells mismatched output" % failures)
    if errors:
        print("    %3i cells failed to complete" % errors)
    if failures or errors:
        sys.exit(1)

if __name__ == '__main__':
    # Test each notebook named on the command line, in order.
    # test_notebook() exits the process on the first notebook that has a
    # mismatching or failing cell.
    for ipynb in sys.argv[1:]:
        print("testing %s" % ipynb)
        # Read the notebook JSON as UTF-8 regardless of the locale.
        # NOTE(review): assumes the notebooks are UTF-8 encoded -- confirm.
        with open(ipynb, encoding='utf-8') as f:
            nb = nbformat.reads_json(f.read())
        test_notebook(nb)