# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. # coding: utf-8 # pylint: disable=invalid-name, no-member """ctypes library of mxnet and helper functions.""" from __future__ import absolute_import import sys import ctypes import atexit import warnings import inspect import numpy as np from . import libinfo warnings.filterwarnings('default', category=DeprecationWarning) __all__ = ['MXNetError'] #---------------------------- # library loading #---------------------------- if sys.version_info[0] == 3: string_types = str, numeric_types = (float, int, np.generic) integer_types = int # this function is needed for python3 # to convert ctypes.char_p .value back to python str py_str = lambda x: x.decode('utf-8') else: string_types = basestring, numeric_types = (float, int, long, np.generic) integer_types = (int, long) py_str = lambda x: x class _NullType(object): """Placeholder for arguments""" def __repr__(self): return '_Null' _Null = _NullType() class MXNetError(Exception): """Error that will be throwed by all mxnet functions.""" pass class NotImplementedForSymbol(MXNetError): def __init__(self, function, alias, *args): super(NotImplementedForSymbol, self).__init__() self.function = function.__name__ self.alias = alias self.args = [str(type(a)) for a in args] def __str__(self): msg = 'Function {}'.format(self.function) if self.alias: msg += ' (namely operator "{}")'.format(self.alias) if self.args: msg += ' with arguments ({})'.format(', '.join(self.args)) msg += ' is not implemented for Symbol and only available in NDArray.' return msg class MXCallbackList(ctypes.Structure): """Structure that holds Callback information. Passed to CustomOpProp.""" _fields_ = [ ('num_callbacks', ctypes.c_int), ('callbacks', ctypes.POINTER(ctypes.CFUNCTYPE(ctypes.c_int))), ('contexts', ctypes.POINTER(ctypes.c_void_p)) ] def _load_lib(): """Load library by searching possible path.""" lib_path = libinfo.find_lib_path() lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_LOCAL) # DMatrix functions lib.MXGetLastError.restype = ctypes.c_char_p return lib # version number __version__ = libinfo.__version__ # library instance of mxnet _LIB = _load_lib() # type definitions mx_uint = ctypes.c_uint mx_float = ctypes.c_float mx_float_p = ctypes.POINTER(mx_float) mx_real_t = np.float32 NDArrayHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p OpHandle = ctypes.c_void_p CachedOpHandle = ctypes.c_void_p SymbolHandle = ctypes.c_void_p ExecutorHandle = ctypes.c_void_p DataIterCreatorHandle = ctypes.c_void_p DataIterHandle = ctypes.c_void_p KVStoreHandle = ctypes.c_void_p RecordIOHandle = ctypes.c_void_p RtcHandle = ctypes.c_void_p #---------------------------- # helper function definition #---------------------------- def check_call(ret): """Check the return value of C API call. This function will raise an exception when an error occurs. Wrap every API call with this function. Parameters ---------- ret : int return value from API calls. """ if ret != 0: raise MXNetError(py_str(_LIB.MXGetLastError())) if sys.version_info[0] < 3: def c_str(string): """Create ctypes char * from a Python string. Parameters ---------- string : string type Python string. Returns ------- str : c_char_p A char pointer that can be passed to C API. Examples -------- >>> x = mx.base.c_str("Hello, World") >>> print x.value Hello, World """ return ctypes.c_char_p(string) else: def c_str(string): """Create ctypes char * from a Python string. Parameters ---------- string : string type Python string. Returns ------- str : c_char_p A char pointer that can be passed to C API. Examples -------- >>> x = mx.base.c_str("Hello, World") >>> print x.value Hello, World """ return ctypes.c_char_p(string.encode('utf-8')) def c_array(ctype, values): """Create ctypes array from a Python array. Parameters ---------- ctype : ctypes data type Data type of the array we want to convert to, such as mx_float. values : tuple or list Data content. Returns ------- out : ctypes array Created ctypes array. Examples -------- >>> x = mx.base.c_array(mx.base.mx_float, [1, 2, 3]) >>> print len(x) 3 >>> x[1] 2.0 """ return (ctype * len(values))(*values) def ctypes2buffer(cptr, length): """Convert ctypes pointer to buffer type. Parameters ---------- cptr : ctypes.POINTER(ctypes.c_char) Pointer to the raw memory region. length : int The length of the buffer. Returns ------- buffer : bytearray The raw byte memory buffer. """ if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)): raise TypeError('expected char pointer') res = bytearray(length) rptr = (ctypes.c_char * length).from_buffer(res) if not ctypes.memmove(rptr, cptr, length): raise RuntimeError('memmove failed') return res def ctypes2numpy_shared(cptr, shape): """Convert a ctypes pointer to a numpy array. The resulting NumPy array shares the memory with the pointer. Parameters ---------- cptr : ctypes.POINTER(mx_float) pointer to the memory region shape : tuple Shape of target `NDArray`. Returns ------- out : numpy_array A numpy array : numpy array. """ if not isinstance(cptr, ctypes.POINTER(mx_float)): raise RuntimeError('expected float pointer') size = 1 for s in shape: size *= s dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents)) return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape) def build_param_doc(arg_names, arg_types, arg_descs, remove_dup=True): """Build argument docs in python style. arg_names : list of str Argument names. arg_types : list of str Argument type information. arg_descs : list of str Argument description information. remove_dup : boolean, optional Whether remove duplication or not. Returns ------- docstr : str Python docstring of parameter sections. """ param_keys = set() param_str = [] for key, type_info, desc in zip(arg_names, arg_types, arg_descs): if key in param_keys and remove_dup: continue if key == 'num_args': continue param_keys.add(key) ret = '%s : %s' % (key, type_info) if len(desc) != 0: ret += '\n ' + desc param_str.append(ret) doc_str = ('Parameters\n' + '----------\n' + '%s\n') doc_str = doc_str % ('\n'.join(param_str)) return doc_str def _notify_shutdown(): """Notify MXNet about a shutdown.""" check_call(_LIB.MXNotifyShutdown()) atexit.register(_notify_shutdown) def add_fileline_to_docstring(module, incursive=True): """Append the definition position to each function contained in module. Examples -------- # Put the following codes at the end of a file add_fileline_to_docstring(__name__) """ def _add_fileline(obj): """Add fileinto to a object. """ if obj.__doc__ is None or 'From:' in obj.__doc__: return fname = inspect.getsourcefile(obj) if fname is None: return try: line = inspect.getsourcelines(obj)[-1] except IOError: return obj.__doc__ += '\n\nFrom:%s:%d' % (fname, line) if isinstance(module, str): module = sys.modules[module] for _, obj in inspect.getmembers(module): if inspect.isbuiltin(obj): continue if inspect.isfunction(obj): _add_fileline(obj) if inspect.ismethod(obj): _add_fileline(obj.__func__) if inspect.isclass(obj) and incursive: add_fileline_to_docstring(obj, False) def _as_list(obj): """A utility function that converts the argument to a list if it is not already. Parameters ---------- obj : object Returns ------- If `obj` is a list or tuple, return it. Otherwise, return `[obj]` as a single-element list. """ if isinstance(obj, (list, tuple)): return obj else: return [obj]