from __future__ import print_function, absolute_import, division import contextlib import gc import pickle import subprocess import sys from numba import unittest_support as unittest from numba.errors import TypingError from numba.targets import registry from .support import TestCase, tag from .serialize_usecases import * class TestDispatcherPickling(TestCase): def run_with_protocols(self, meth, *args, **kwargs): for proto in range(pickle.HIGHEST_PROTOCOL + 1): meth(proto, *args, **kwargs) @contextlib.contextmanager def simulate_fresh_target(self): dispatcher_cls = registry.dispatcher_registry['cpu'] old_descr = dispatcher_cls.targetdescr # Simulate fresh targetdescr dispatcher_cls.targetdescr = type(dispatcher_cls.targetdescr)() try: yield finally: # Be sure to reinstantiate old descriptor, otherwise other # objects may be out of sync. dispatcher_cls.targetdescr = old_descr def check_call(self, proto, func, expected_result, args): def check_result(func): if (isinstance(expected_result, type) and issubclass(expected_result, Exception)): self.assertRaises(expected_result, func, *args) else: self.assertPreciseEqual(func(*args), expected_result) # Control check_result(func) pickled = pickle.dumps(func, proto) with self.simulate_fresh_target(): new_func = pickle.loads(pickled) check_result(new_func) @tag('important') def test_call_with_sig(self): self.run_with_protocols(self.check_call, add_with_sig, 5, (1, 4)) # Compilation has been disabled => float inputs will be coerced to int self.run_with_protocols(self.check_call, add_with_sig, 5, (1.2, 4.2)) @tag('important') def test_call_without_sig(self): self.run_with_protocols(self.check_call, add_without_sig, 5, (1, 4)) self.run_with_protocols(self.check_call, add_without_sig, 5.5, (1.2, 4.3)) # Object mode is enabled self.run_with_protocols(self.check_call, add_without_sig, "abc", ("a", "bc")) @tag('important') def test_call_nopython(self): self.run_with_protocols(self.check_call, add_nopython, 5.5, (1.2, 4.3)) # Object mode is disabled self.run_with_protocols(self.check_call, add_nopython, TypingError, (object(), object())) def test_call_nopython_fail(self): # Compilation fails self.run_with_protocols(self.check_call, add_nopython_fail, TypingError, (1, 2)) def test_call_objmode_with_global(self): self.run_with_protocols(self.check_call, get_global_objmode, 7.5, (2.5,)) def test_call_closure(self): inner = closure(1) self.run_with_protocols(self.check_call, inner, 6, (2, 3)) def check_call_closure_with_globals(self, **jit_args): inner = closure_with_globals(3.0, **jit_args) self.run_with_protocols(self.check_call, inner, 7.0, (4.0,)) def test_call_closure_with_globals_nopython(self): self.check_call_closure_with_globals(nopython=True) def test_call_closure_with_globals_objmode(self): self.check_call_closure_with_globals(forceobj=True) def test_call_closure_calling_other_function(self): inner = closure_calling_other_function(3.0) self.run_with_protocols(self.check_call, inner, 11.0, (4.0, 6.0)) def test_call_closure_calling_other_closure(self): inner = closure_calling_other_closure(3.0) self.run_with_protocols(self.check_call, inner, 8.0, (4.0,)) def test_call_dyn_func(self): # Check serializing a dynamically-created function self.run_with_protocols(self.check_call, dyn_func, 36, (6,)) def test_call_dyn_func_objmode(self): # Same with an object mode function self.run_with_protocols(self.check_call, dyn_func_objmode, 36, (6,)) def test_renamed_module(self): # Issue #1559: using a renamed module (e.g. `import numpy as np`) # should not fail serializing expected = get_renamed_module(0.0) self.run_with_protocols(self.check_call, get_renamed_module, expected, (0.0,)) def test_call_generated(self): self.run_with_protocols(self.check_call, generated_add, 46, (1, 2)) self.run_with_protocols(self.check_call, generated_add, 1j + 7, (1j, 2)) @tag('important') def test_other_process(self): """ Check that reconstructing doesn't depend on resources already instantiated in the original process. """ func = closure_calling_other_closure(3.0) pickled = pickle.dumps(func) code = """if 1: import pickle data = {pickled!r} func = pickle.loads(data) res = func(4.0) assert res == 8.0, res """.format(**locals()) subprocess.check_call([sys.executable, "-c", code]) @tag('important') def test_reuse(self): """ Check that deserializing the same function multiple times re-uses the same dispatcher object. Note that "same function" is intentionally under-specified. """ func = closure(5) pickled = pickle.dumps(func) func2 = closure(6) pickled2 = pickle.dumps(func2) f = pickle.loads(pickled) g = pickle.loads(pickled) h = pickle.loads(pickled2) self.assertIs(f, g) self.assertEqual(f(2, 3), 10) g.disable_compile() self.assertEqual(g(2, 4), 11) self.assertIsNot(f, h) self.assertEqual(h(2, 3), 11) # Now make sure the original object doesn't exist when deserializing func = closure(7) func(42, 43) pickled = pickle.dumps(func) del func gc.collect() f = pickle.loads(pickled) g = pickle.loads(pickled) self.assertIs(f, g) self.assertEqual(f(2, 3), 12) g.disable_compile() self.assertEqual(g(2, 4), 13) def test_imp_deprecation(self): """ The imp module was deprecated in v3.4 in favour of importlib """ code = """if 1: import pickle import warnings with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always', DeprecationWarning) from numba import njit @njit def foo(x): return x + 1 foo(1) serialized_foo = pickle.dumps(foo) for x in w: if 'serialize.py' in x.filename: assert "the imp module is deprecated" not in x.msg """ subprocess.check_call([sys.executable, "-c", code]) if __name__ == '__main__': unittest.main()