/*!
 * Copyright (c) 2015 by Contributors
 * \file half.h
 * \brief definition of half (float16) type.
 *
 * \author Junyuan Xie
 */
#ifndef MSHADOW_HALF_H_
#define MSHADOW_HALF_H_
#include "./base.h"

#if MSHADOW_USE_F16C
  #include <x86intrin.h>
#endif  // MSHADOW_USE_F16C

// This flag dictates rounding for the float2half() routine only (used generally on Windows),
// not the f16c lib or cuda v7.5 (or later) behavior which is fixed at round-to-nearest-even.
#ifndef MSHADOW_HALF_ROUND_TO_NEAREST
#define MSHADOW_HALF_ROUND_TO_NEAREST 1
#endif

#if (MSHADOW_USE_CUDA && CUDA_VERSION >= 7050)
  #define MSHADOW_CUDA_HALF 1
  #include <cuda_fp16.h>
  #if defined(__CUDA_ARCH__)
    /*! \brief __half2float_warp */
    __host__ __device__ float __half2float_warp(const volatile __half& h) { /* NOLINT(*) */
      __half val;
#if CUDA_VERSION >= 9000
      val = const_cast<__half&>(h);
#else
      val.x = h.x;
#endif
      return __half2float(val);
    }
  #endif
#else
  #define MSHADOW_CUDA_HALF 0
#endif

/*! \brief namespace for mshadow */
namespace mshadow {
/* \brief name space for host/device portable half-precision floats */
namespace half {
#define MSHADOW_HALF_OPERATOR(RTYPE, OP)                                  \
  MSHADOW_XINLINE RTYPE operator OP (half_t a, half_t b) {                \
    return RTYPE(float(a) OP float(b));  /* NOLINT(*) */                  \
  }                                                                       \
  template<typename T>                                                    \
  MSHADOW_XINLINE RTYPE operator OP (half_t a, T b) {                     \
    return RTYPE(float(a) OP float(b));  /* NOLINT(*) */                  \
  }                                                                       \
  template<typename T>                                                    \
  MSHADOW_XINLINE RTYPE operator OP (T a, half_t b) {                     \
    return RTYPE(float(a) OP float(b));  /* NOLINT(*) */                  \
  }

#define MSHADOW_HALF_ASSIGNOP(AOP, OP)                                    \
  template<typename T>                                                    \
  MSHADOW_XINLINE half_t operator AOP (const T& a) {                      \
    return *this = half_t(float(*this) OP float(a));  /* NOLINT(*)*/      \
  }                                                                       \
  template<typename T>                                                    \
  MSHADOW_XINLINE half_t operator AOP (const volatile T& a) volatile {    \
    return *this = half_t(float(*this) OP float(a));  /* NOLINT(*)*/      \
  }

#if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
#define MSHADOW_HALF_CONVERSIONOP(T)                                      \
  MSHADOW_XINLINE operator T() const {                                    \
    return T(__half2float(cuhalf_));  /* NOLINT(*)*/                      \
  }                                                                       \
  MSHADOW_XINLINE operator T() const volatile {                           \
    return T(__half2float_warp(cuhalf_));  /* NOLINT(*)*/                 \
  }
#elif(MSHADOW_USE_F16C)
#define MSHADOW_HALF_CONVERSIONOP(T)                                      \
  MSHADOW_XINLINE operator T() const {                                    \
    return T(_cvtsh_ss(half_));  /* NOLINT(*)*/                           \
  }                                                                       \
  MSHADOW_XINLINE operator T() const volatile {                           \
    return T(_cvtsh_ss(half_));  /* NOLINT(*)*/                           \
  }
#else
#define MSHADOW_HALF_CONVERSIONOP(T)                                      \
  MSHADOW_XINLINE operator T() const {                                    \
    return T(half2float(half_));  /* NOLINT(*)*/                          \
  }                                                                       \
  MSHADOW_XINLINE operator T() const volatile {                           \
    return T(half2float(half_));  /* NOLINT(*)*/                          \
  }
#endif  // (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
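// MSHADOW_HALF_CONVERSIONOP above selects one of three half->float paths at
// compile time: the __half2float device intrinsic when compiling CUDA device
// code (cuda >= 7.5), the F16C _cvtsh_ss intrinsic when MSHADOW_USE_F16C is
// set, and the portable software half2float() defined inside half_t otherwise.
// half_t instantiates it only for T = float, so the operators declared via the
// macros above widen to float, do the work in single precision, and (for the
// arithmetic operators) narrow the result back to half precision.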
class MSHADOW_ALIGNED(2) half_t {
 public:
  union {
    uint16_t half_;
#if MSHADOW_CUDA_HALF
    __half cuhalf_;
#endif  // MSHADOW_CUDA_HALF
  };

  static MSHADOW_XINLINE half_t Binary(uint16_t value) {
    half_t res;
    res.half_ = value;
    return res;
  }

  MSHADOW_XINLINE half_t() {}

#if MSHADOW_CUDA_HALF
  MSHADOW_XINLINE explicit half_t(const __half& value) {
    cuhalf_ = value;
  }
#endif  // MSHADOW_CUDA_HALF

  MSHADOW_XINLINE half_t(const float& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const double& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const int8_t& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const uint8_t& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const int32_t& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const uint32_t& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const int64_t& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const uint64_t& value) { constructor(value); }

  MSHADOW_HALF_CONVERSIONOP(float)

  MSHADOW_HALF_ASSIGNOP(+=, +)
  MSHADOW_HALF_ASSIGNOP(-=, -)
  MSHADOW_HALF_ASSIGNOP(*=, *)
  MSHADOW_HALF_ASSIGNOP(/=, /)

  MSHADOW_XINLINE half_t operator+() {
    return *this;
  }

  MSHADOW_XINLINE half_t operator-() {
    return half_t(-float(*this));  // NOLINT(*)
  }

  MSHADOW_XINLINE half_t operator=(const half_t& a) {
    half_ = a.half_;
    return a;
  }

  template<typename T>
  MSHADOW_XINLINE half_t operator=(const T& a) {
    return *this = half_t(a);  /* NOLINT(*)*/
  }

  MSHADOW_XINLINE half_t operator=(const half_t& a) volatile {
    half_ = a.half_;
    return a;
  }

  template<typename T>
  MSHADOW_XINLINE half_t operator=(const T& a) volatile {
    return *this = half_t(a);  /* NOLINT(*)*/
  }

 private:
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  static int const fp16FractionBits = 10;
  static int const fp32FractionBits = 23;
  static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits);  // == 0x7fffff
  static int32_t const fp32HiddenBit = 1 << fp32FractionBits;  // == 0x800000
  static int const shift = fp32FractionBits - fp16FractionBits;  // == 13
  static int const shiftSign = 16;
  static int32_t const expAdjust = 127 - 15;  // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)

  static int32_t const infN = 0x7F800000;   // flt32 infinity
  static int32_t const maxN = 0x477FFFFF;   // max flt32 that's a flt16 normal after >> by shift
  static int32_t const minN = 0x38800000;   // min flt16 normal as a flt32
  static int32_t const maxZ = 0x33000000;   // max fp32 number that's still rounded to zero in fp16
  static int32_t const signN = 0x80000000;  // flt32 sign bit

  static int32_t const infC = infN >> shift;
  static int32_t const nanN = (infC + 1) << shift;  // minimum flt16 nan as a flt32
  static int32_t const maxC = maxN >> shift;
  static int32_t const minC = minN >> shift;
  static int32_t const signC = signN >> shiftSign;  // flt16 sign bit

  static int32_t const mulN = 0x52000000;  // (1 << 23) / minN
  static int32_t const mulC = 0x33800000;  // minN / (1 << (23 - shift))

  static int32_t const subC = 0x003FF;  // max flt32 subnormal down shifted
  static int32_t const norC = 0x00400;  // min flt32 normal down shifted

  static int32_t const maxD = infC - maxC - 1;
  static int32_t const minD = minC - subC - 1;
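  // Worked example of the MSHADOW_HALF_ROUND_TO_NEAREST logic in float2half()
  // below (illustrative only, normal-range case):
  //   1 + 2^-11   has fp32 bits 0x3F801000; (bits & 0x3fff) == 0x1000 marks an
  //   exact tie, so nothing is added and the value truncates to 0x3C00 (1.0),
  //   i.e. the tie rounds toward the even fp16 mantissa.
  //   1 + 3*2^-11 has fp32 bits 0x3F803000; (bits & 0x3fff) != 0x1000, so
  //   0x1000 is added before the >> 13, giving 0x3C02 (1 + 2^-9), which is
  //   again the nearest-even neighbour.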
  MSHADOW_XINLINE uint16_t float2half(const float& value) const {
    Bits v;
    v.f = value;
    uint32_t sign = v.si & signN;    // grab sign bit
    v.si ^= sign;                    // clear sign bit from v
    sign >>= shiftSign;              // logical shift sign to fp16 position

    if (v.si <= maxZ) {
      // Handle eventual zeros here to ensure vshift will not exceed 32 below.
      v.ui = 0;
    } else if (v.si < minN) {
      // Handle denorms
      uint32_t exp32 = v.ui >> fp32FractionBits;
      int32_t exp16 = exp32 - expAdjust;
      // If exp16 == 0 (just into the denorm range), then significand should be shifted right 1.
      // Smaller (so negative) exp16 values should result in greater right shifts.
      uint32_t vshift = 1 - exp16;
      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
      v.ui = significand >> vshift;
      // The only time it's *not* OK to add 0x1000 (i.e. half the flt16 fraction lsb) is
      // when the lsb of the flt16 fraction == 0 (so not rounding up to even) and the additional
      // bits to the right of the lsb are 1000... (including flt32 significand bits
      // that may be lost during the above vshift).  The first term below will always
      // be true for vshift >= 12 (since even the 'hidden bit' has been shifted to the
      // right of the '1' bit in 0x1000).  And when vshift <= 11, both terms combine to make
      // the proper test of the flt32 significand bits, including those lost during the vshift.
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
      // Rounding may increase the exponent to 1, but that's OK.
      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
#endif
    } else if (v.si <= maxN) {
      // Handle norms
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
      // Rounding may increase the exponent, possibly creating an inf, but that's OK.
      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
#endif
      v.ui -= expAdjust << fp32FractionBits;
    } else if (v.si <= infN) {
      v.si = infN;
    } else if (v.si < nanN) {
      v.si = nanN;
    }

    v.ui >>= shift;
    return sign | (v.ui & 0x7fff);
  }

  // Same as above routine, except for addition of volatile keyword
  MSHADOW_XINLINE uint16_t float2half(const volatile float& value) const volatile {  // NOLINT (*)
    Bits v;
    v.f = value;
    uint32_t sign = v.si & signN;    // grab sign bit
    v.si ^= sign;                    // clear sign bit from v
    sign >>= shiftSign;              // logical shift sign to fp16 position

    if (v.si <= maxZ) {
      // Handle eventual zeros here to ensure vshift will not exceed 32 below.
      v.ui = 0;
    } else if (v.si < minN) {
      // Handle denorms
      uint32_t exp32 = v.ui >> fp32FractionBits;
      int32_t exp16 = exp32 - expAdjust;
      // If exp16 == 0 (just into the denorm range), then significand should be shifted right 1.
      // Smaller (so negative) exp16 values should result in greater right shifts.
      uint32_t vshift = 1 - exp16;
      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
      v.ui = significand >> vshift;
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
      // Rounding may increase the exponent to 1, but that's OK.
      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
#endif
    } else if (v.si <= maxN) {
      // Handle norms
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
      // Rounding may increase the exponent, possibly creating an inf, but that's OK.
      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
#endif
      v.ui -= expAdjust << fp32FractionBits;
    } else if (v.si <= infN) {
      v.si = infN;
    } else if (v.si < nanN) {
      v.si = nanN;
    }

    v.ui >>= shift;
    return sign | (v.ui & 0x7fff);
  }

  MSHADOW_XINLINE float half2float(const uint16_t& value) const {
    Bits v;
    v.ui = value;
    int32_t sign = v.si & signC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;
  }

  MSHADOW_XINLINE float half2float(const volatile uint16_t& value) const volatile {  // NOLINT(*)
    Bits v;
    v.ui = value;
    int32_t sign = v.si & signC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;
  }

  template<typename T>
  MSHADOW_XINLINE void constructor(const T& value) {
#if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
    cuhalf_ = __float2half(float(value));  // NOLINT(*)
#elif(MSHADOW_USE_F16C)
    half_ = _cvtss_sh(static_cast<float>(value), 0);
#else /* !MSHADOW_CUDA_HALF && !MSHADOW_USE_F16C */
    half_ = float2half(float(value));  // NOLINT(*)
#endif  /* !MSHADOW_CUDA_HALF && !MSHADOW_USE_F16C */
  }
};
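// For reference, each MSHADOW_HALF_OPERATOR(RTYPE, OP) line below expands to
// roughly the following (sketch of the half_t/half_t overload only; the two
// mixed-type template overloads follow the same pattern):
//
//   MSHADOW_XINLINE half_t operator+(half_t a, half_t b) {
//     return half_t(float(a) + float(b));
//   }
//
// so the arithmetic instantiations yield half_t while the comparison
// instantiations (RTYPE == bool) yield bool, all computed in single precision.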
/*! \brief overloaded + operator for half_t */
MSHADOW_HALF_OPERATOR(half_t, +)
/*! \brief overloaded - operator for half_t */
MSHADOW_HALF_OPERATOR(half_t, -)
/*! \brief overloaded * operator for half_t */
MSHADOW_HALF_OPERATOR(half_t, *)
/*! \brief overloaded / operator for half_t */
MSHADOW_HALF_OPERATOR(half_t, /)
/*! \brief overloaded > operator for half_t */
MSHADOW_HALF_OPERATOR(bool, >)
/*! \brief overloaded < operator for half_t */
MSHADOW_HALF_OPERATOR(bool, <)
/*! \brief overloaded >= operator for half_t */
MSHADOW_HALF_OPERATOR(bool, >=)
/*! \brief overloaded <= operator for half_t */
MSHADOW_HALF_OPERATOR(bool, <=)

#define MSHADOW_HALF_MIN mshadow::half::half_t::Binary(0xFBFF);
#define MSHADOW_HALF_MAX mshadow::half::half_t::Binary(0x7BFF);
}  // namespace half
}  // namespace mshadow
#endif  // MSHADOW_HALF_H_
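// Illustrative usage sketch (not part of this header; assumes mshadow is on
// the include path):
//
//   mshadow::half::half_t a = 1.5f;   // implicit construction from float
//   a += mshadow::half::half_t(2.0);  // computed in float, stored as fp16
//   float f = a;                      // implicit conversion back: 3.5f
//   mshadow::half::half_t hi = MSHADOW_HALF_MAX   // 0x7BFF == 65504, largest
//                                                 // finite fp16; the macro
//                                                 // supplies its own ';'
//   mshadow::half::half_t lo = MSHADOW_HALF_MIN   // 0xFBFF == -65504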