# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""Interface for NDArray functions executed by torch backend.
Install Torch and compile with USE_TORCH=1 to use this module."""
from __future__ import absolute_import

import ctypes
import sys
from .base import _LIB
from .base import c_array, c_str_array, c_handle_array, py_str, build_param_doc as _build_param_doc
from .base import mx_uint, mx_float, FunctionHandle
from .base import check_call
from .ndarray import NDArray, _new_empty_handle

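# Preload LuaJIT with globally visible symbols (RTLD_GLOBAL), presumably so the
# Torch libraries used by the backend can resolve them; silently skipped when
# LuaJIT is not installed.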
try:
    _LUAJIT = ctypes.CDLL("libluajit.so", mode=ctypes.RTLD_GLOBAL)
except OSError:
    _LUAJIT = None
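
# A minimal usage sketch, assuming MXNet was built with USE_TORCH=1 and that the
# backend registers a Torch maths function such as '_th_exp' (the exact set of
# available functions depends on the Torch installation):
#
#   import mxnet as mx
#   x = mx.nd.ones((2, 2))
#   y = mx.th.exp(x)             # output NDArray is allocated automatically
#   out = mx.nd.empty((2, 2))
#   mx.th.exp(out, x)            # or pass a preallocated output buffer first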

# pylint: disable=too-many-locals, invalid-name
def _make_torch_function(handle):
    """Create a Torch function from the FunctionHandle."""
    # Get the properties of the function
    n_used_vars = mx_uint()
    n_scalars = mx_uint()
    n_mutate_vars = mx_uint()
    type_mask = ctypes.c_int()
    check_call(_LIB.MXFuncDescribe(
        handle,
        ctypes.byref(n_used_vars),
        ctypes.byref(n_scalars),
        ctypes.byref(n_mutate_vars),
        ctypes.byref(type_mask)))
    n_mutate_vars = n_mutate_vars.value
    n_used_vars = n_used_vars.value
    n_scalars = n_scalars.value
    type_mask = type_mask.value

    # Get the information from the function
    name = ctypes.c_char_p()
    desc = ctypes.c_char_p()
    num_args = mx_uint()
    arg_names = ctypes.POINTER(ctypes.c_char_p)()
    arg_types = ctypes.POINTER(ctypes.c_char_p)()
    arg_descs = ctypes.POINTER(ctypes.c_char_p)()
    ret_type = ctypes.c_char_p()

    check_call(_LIB.MXFuncGetInfo(
        handle, ctypes.byref(name), ctypes.byref(desc),
        ctypes.byref(num_args),
        ctypes.byref(arg_names),
        ctypes.byref(arg_types),
        ctypes.byref(arg_descs),
        ctypes.byref(ret_type)))
    func_name = py_str(name.value)
    if not func_name.startswith('_th_'):
        return None
    narg = int(num_args.value)
    param_str = _build_param_doc(
        [py_str(arg_names[i]) for i in range(narg)],
        [py_str(arg_types[i]) for i in range(narg)],
        [py_str(arg_descs[i]) for i in range(narg)])

    if n_mutate_vars > 1:
        res = ', '.join(['res%d' % i for i in range(n_mutate_vars)]) + ' '
    else:
        res = 'res '
    doc_str = (('Interface for Torch function {name}.\n' +
                'Invoke with\n{res}= mxnet.th.{name}(Parameters)\nor\n'+
                'mxnet.th.{name}({res}, Parameters).\n\n' +
                '{param_str}\n' +
                'Reference: ' +
                'https://github.com/torch/torch7/blob/master/doc/maths.md\n').format(
                    name=func_name[4:], param_str=param_str,
                    res=res))

    def generic_torch_function(*args, **kwargs):
        """Invoke this function by passing in parameters.

        Parameters
        ----------
        *args
            Positional inputs (scalars, strings, booleans, and `NDArray`s).

        Returns
        -------
        out : NDArray or list of NDArray
            The result of the computation; a list is returned when the
            function has more than one output.
        """
        ndargs = []
        arg_format = ''
        value = ''
        for arg in args:
            if isinstance(arg, NDArray):
                ndargs.append(arg)
                arg_format += 'n'
                value += ','
            elif isinstance(arg, bool):
                # check bool before int: bool is a subclass of int, so the
                # 'i' branch would otherwise swallow True/False
                arg_format += 'b'
                value += str(arg) + ','
            elif isinstance(arg, int):
                arg_format += 'i'
                value += str(arg) + ','
            elif isinstance(arg, str):
                arg_format += 's'
                value += str(arg) + ','
            elif isinstance(arg, float):
                arg_format += 'f'
                value += str(arg) + ','
        value = value[:-1]
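        # For example, args == (x, 2, 0.5) with x an NDArray yields
        # arg_format == 'nif' and value == ',2,0.5' at this point; the NDArray
        # itself travels through ndargs, only a placeholder comma is encoded.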
        if len(ndargs) == n_used_vars:
            ndargs = [NDArray(_new_empty_handle()) for _ in range(n_mutate_vars)] + ndargs
            arg_format = 'n'*n_mutate_vars + arg_format
            value = ','*n_mutate_vars + value
        elif len(ndargs) == n_mutate_vars + n_used_vars:
            pass
        else:
            raise AssertionError(('Incorrect number of input NDArrays. ' +
                                  'Expected either %d (inputs only) or ' +
                                  '%d (output buffers) + %d (inputs).') %
                                 (n_used_vars, n_mutate_vars, n_used_vars))

        kwargs['format'] = arg_format
        kwargs['args'] = value

        for k in kwargs:
            kwargs[k] = str(kwargs[k])

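        # MXFuncInvokeEx receives the input handles (use vars), an empty scalar
        # array (scalars are forwarded through the 'format'/'args' keyword
        # arguments instead), the output handles (mutate vars), and the string
        # keyword arguments.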
        check_call(_LIB.MXFuncInvokeEx(
            handle,
            c_handle_array(ndargs[n_mutate_vars:]), # pylint: disable=invalid-slice-index
            c_array(mx_float, []),
            c_handle_array(ndargs[:n_mutate_vars]),   # pylint: disable=invalid-slice-index
            ctypes.c_int(len(kwargs)),
            c_str_array(kwargs.keys()),
            c_str_array(kwargs.values())))

        if n_mutate_vars == 1:
            return ndargs[0]
        else:
            return ndargs[:n_mutate_vars] # pylint: disable=invalid-slice-index

    # End of function declaration
    ret_function = generic_torch_function
    ret_function.__name__ = func_name[4:]
    ret_function.__doc__ = doc_str
    return ret_function

# pylint: enable=too-many-locals, invalid-name

def _init_torch_module():
    """List and add all the torch backed ndarray functions to current module."""
    plist = ctypes.POINTER(FunctionHandle)()
    size = ctypes.c_uint()
    check_call(_LIB.MXListFunctions(ctypes.byref(size),
                                    ctypes.byref(plist)))

    module_obj = sys.modules[__name__]
    for i in range(size.value):
        hdl = FunctionHandle(plist[i])
        function = _make_torch_function(hdl)
        # _make_torch_function returns None for non-Torch functions; register the rest
        if function is not None:
            setattr(module_obj, function.__name__, function)

# Initialize the Torch module at import time
_init_torch_module()