# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. """Interface to runtime cuda kernel compile module.""" from __future__ import absolute_import import ctypes from .base import _LIB, NDArrayHandle, RtcHandle, mx_uint, c_array, check_call class Rtc(object): """MXRtc object in mxnet. This class allow you to write CUDA kernels in Python and call them with NDArray. Parameters ---------- name : str Name of the kernel. inputs : tuple of (str, mxnet.ndarray) List of input names and ndarray. outputs : tuple of (str, mxnet.ndarray) List of output names and ndarray. kernel : str The actual kernel code. Note that this is only the body of the kernel, i.e. after { and before }. Rtc will decorate the kernel. For example, if ``name = "mykernel"`` and inputs = [('x', mx.nd.zeros((10,)))] outputs = [('y', mx.nd.zeros((10,)))] kernel = "y[threadIdx.x] = x[threadIdx.x];", then the compiled kernel will be: extern "C" __global__ mykernel(float *x, float *y) { const int x_ndim = 1; const int x_dims = { 10 }; const int y_ndim = 1; const int y_dims = { 10 }; y[threadIdx.x] = x[threadIdx.x]; } """ def __init__(self, name, inputs, outputs, kernel): self.handle = RtcHandle() input_names = ctypes.cast(c_array(ctypes.c_char_p, [i[0] for i in inputs]), ctypes.POINTER(ctypes.c_char_p)) output_names = ctypes.cast(c_array(ctypes.c_char_p, [i[0] for i in outputs]), ctypes.POINTER(ctypes.c_char_p)) input_nds = ctypes.cast(c_array(NDArrayHandle, [i[1].handle for i in inputs]), ctypes.POINTER(NDArrayHandle)) output_nds = ctypes.cast(c_array(NDArrayHandle, [i[1].handle for i in outputs]), ctypes.POINTER(NDArrayHandle)) check_call(_LIB.MXRtcCreate(ctypes.c_char_p(name), mx_uint(len(inputs)), mx_uint(len(outputs)), input_names, output_names, input_nds, output_nds, ctypes.c_char_p(kernel), ctypes.byref(self.handle))) def __del__(self): check_call(_LIB.MXRtcFree(self.handle)) def push(self, inputs, outputs, grid_dims, block_dims): """Run the kernel. Parameters ---------- inputs : list of NDArray List of inputs. Can contain different NDArrays than those used for the constructor, but its elements must have the same shapes and appear in the same order. outputs : list of NDArray List of outputs. Can contain different ndarrays than used for the constructor, but must have the same shapes and appear in the same order. grid_dims : tuple of 3 uint Grid dimension for kernel launch. block_dims : tuple of 3 uint Block dimension for kernel launch. """ input_nds = ctypes.cast(c_array(NDArrayHandle, [i.handle for i in inputs]), ctypes.POINTER(NDArrayHandle)) output_nds = ctypes.cast(c_array(NDArrayHandle, [i.handle for i in outputs]), ctypes.POINTER(NDArrayHandle)) check_call(_LIB.MXRtcPush(self.handle, mx_uint(len(inputs)), mx_uint(len(outputs)), input_nds, output_nds, mx_uint(grid_dims[0]), mx_uint(grid_dims[1]), mx_uint(grid_dims[2]), mx_uint(block_dims[0]), mx_uint(block_dims[1]), mx_uint(block_dims[2])))