/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file 3rdparty/byodt/my-custom-datatype.cc
 * \brief Example Custom Datatype with the Bring Your Own Datatypes (BYODT) framework.
 * This is a toy example that under the hood simulates floats.
 *
 * Users interested in using the BYODT framework can use this file as a template.
 *
 * TODO(@gussmith23 @hypercubestart) Link to BYODT docs when they exist?
 */
#include <tvm/runtime/c_runtime_api.h>

#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
#include <type_traits>

// Custom datatypes are stored as bits in a uint of the appropriate bit length.
// Thus, when TVM calls these C functions,
// the arguments are uints that need to be reinterpreted as your custom datatype.
//
// When returning, your custom datatype needs to be re-wrapped into a uint,
// which can be thought of as just a wrapper for the raw bits that represent your custom datatype.

/*!
 * \brief Interpret the 32 raw bits in `in` as a value of the custom datatype T.
 * \tparam T The custom datatype; must be trivially copyable and exactly 32 bits wide.
 * \param in The raw bit pattern.
 * \return The value of type T whose object representation equals `in`.
 */
template <class T>
TVM_DLL T Uint32ToCustom32(uint32_t in) {
  // This is a helper function to interpret the uint as your custom datatype.
  // The following lines should be replaced with the appropriate function
  // that interprets the bits in `in` and returns your custom datatype.
  static_assert(sizeof(T) == sizeof(uint32_t), "custom datatype must be exactly 32 bits wide");
  static_assert(std::is_trivially_copyable<T>::value,
                "custom datatype must be trivially copyable to be rebuilt from raw bits");
  // Copy the bytes instead of dereferencing a reinterpret_cast'ed pointer:
  // reading a uint32_t object through an unrelated T* breaks the aliasing rules.
  T custom;
  std::memcpy(&custom, &in, sizeof(custom));
  return custom;
}

/*!
 * \brief Wrap a value of the custom datatype T into its 32 raw bits.
 * \tparam T The custom datatype; must be trivially copyable and exactly 32 bits wide.
 * \param in The custom datatype value.
 * \return The object representation of `in` as a uint32_t.
 */
template <class T>
TVM_DLL uint32_t Custom32ToUint32(T in) {
  // This is a helper function to wrap your custom datatype in a uint.
  // The following lines should be replaced with the appropriate function
  // that converts your custom datatype into a uint.
  static_assert(sizeof(T) == sizeof(uint32_t), "custom datatype must be exactly 32 bits wide");
  static_assert(std::is_trivially_copyable<T>::value,
                "custom datatype must be trivially copyable to be stored as raw bits");
  // Byte copy for the same aliasing reason as in Uint32ToCustom32.
  uint32_t bits;
  std::memcpy(&bits, &in, sizeof(bits));
  return bits;
}

extern "C" {

/*! \brief Return the bits of the minimum (most negative finite) representable value. */
TVM_DLL uint32_t MinCustom32() {
  // return minimum representable value
  float min = std::numeric_limits<float>::lowest();
  return Custom32ToUint32<float>(min);
}

/*! \brief Convert a custom datatype value (given as raw bits) to float. */
TVM_DLL float Custom32ToFloat(uint32_t in) {
  // cast from custom datatype to float
  float custom_datatype = Uint32ToCustom32<float>(in);
  // our custom datatype is float, so the following redundant cast to float
  // is to remind users to cast their own custom datatype to float
  return static_cast<float>(custom_datatype);
}

/*! \brief Convert a float to the custom datatype and return its raw bits. */
TVM_DLL uint32_t FloatToCustom32(float in) {
  // cast from float to custom datatype
  return Custom32ToUint32<float>(in);
}

/*! \brief Return the bits of a + b, with a and b given as raw bits. */
TVM_DLL uint32_t Custom32Add(uint32_t a, uint32_t b) {
  // add operation
  float acustom = Uint32ToCustom32<float>(a);
  float bcustom = Uint32ToCustom32<float>(b);
  return Custom32ToUint32<float>(acustom + bcustom);
}

/*! \brief Return the bits of a - b, with a and b given as raw bits. */
TVM_DLL uint32_t Custom32Sub(uint32_t a, uint32_t b) {
  // subtract
  float acustom = Uint32ToCustom32<float>(a);
  float bcustom = Uint32ToCustom32<float>(b);
  return Custom32ToUint32<float>(acustom - bcustom);
}

/*! \brief Return the bits of a * b, with a and b given as raw bits. */
TVM_DLL uint32_t Custom32Mul(uint32_t a, uint32_t b) {
  // multiply
  float acustom = Uint32ToCustom32<float>(a);
  float bcustom = Uint32ToCustom32<float>(b);
  return Custom32ToUint32<float>(acustom * bcustom);
}

/*! \brief Return the bits of a / b, with a and b given as raw bits. */
TVM_DLL uint32_t Custom32Div(uint32_t a, uint32_t b) {
  // divide
  float acustom = Uint32ToCustom32<float>(a);
  float bcustom = Uint32ToCustom32<float>(b);
  return Custom32ToUint32<float>(acustom / bcustom);
}

/*!
 * \brief Return the bits of the larger of a and b, with a and b given as raw bits.
 * \note b is returned whenever `a > b` is false, which includes the case
 *       where either operand is NaN.
 */
TVM_DLL uint32_t Custom32Max(uint32_t a, uint32_t b) {
  // max
  float acustom = Uint32ToCustom32<float>(a);
  float bcustom = Uint32ToCustom32<float>(b);
  return Custom32ToUint32<float>(acustom > bcustom ? acustom : bcustom);
}

// The math functions below are std::-qualified on purpose: <cmath> only
// guarantees the names inside namespace std, and the qualified call always
// selects the float overload rather than a global double-precision C function.

/*! \brief Return the bits of sqrt(a), with a given as raw bits. */
TVM_DLL uint32_t Custom32Sqrt(uint32_t a) {
  // sqrt
  float acustom = Uint32ToCustom32<float>(a);
  return Custom32ToUint32<float>(std::sqrt(acustom));
}

/*! \brief Return the bits of exp(a), with a given as raw bits. */
TVM_DLL uint32_t Custom32Exp(uint32_t a) {
  // exponential
  float acustom = Uint32ToCustom32<float>(a);
  return Custom32ToUint32<float>(std::exp(acustom));
}

/*! \brief Return the bits of the natural logarithm of a, with a given as raw bits. */
TVM_DLL uint32_t Custom32Log(uint32_t a) {
  // log
  float acustom = Uint32ToCustom32<float>(a);
  return Custom32ToUint32<float>(std::log(acustom));
}

/*! \brief Return the bits of 1 / (1 + exp(-a)), with a given as raw bits. */
TVM_DLL uint32_t Custom32Sigmoid(uint32_t a) {
  // sigmoid
  float acustom = Uint32ToCustom32<float>(a);
  const float one = 1.0f;
  return Custom32ToUint32<float>(one / (one + std::exp(-acustom)));
}

/*! \brief Return the bits of tanh(a), with a given as raw bits. */
TVM_DLL uint32_t Custom32Tanh(uint32_t a) {
  // tanh
  float acustom = Uint32ToCustom32<float>(a);
  return Custom32ToUint32<float>(std::tanh(acustom));
}
}