/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file int_operator.h
 * \brief Additional useful operators for integers.
 */
#ifndef TVM_ARITH_INT_OPERATOR_H_
#define TVM_ARITH_INT_OPERATOR_H_

#include <limits>
#include <utility>

namespace tvm {
namespace arith {

/*!
 * \brief Generic overflow check for a binary integer operation on x and y.
 *
 *  This primary template is the fallback for operators that have no
 *  dedicated specialization: it inspects nothing and always reports that
 *  no overflow occurs.
 *
 * \param x The left operand.
 * \param y The right operand.
 * \param min_value The minimum value of the domain.
 * \param max_value The maximum value of the domain.
 * \return Whether overflow can happen (always false for this fallback).
 * \tparam Op The integer operator.
 */
template <typename Op>
inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
  // No operator-specific rule is available here, so nothing is flagged.
  constexpr bool kCanOverflow = false;
  return kCanOverflow;
}

/*!
 * \brief Overflow check for x + y against the domain [min_value, max_value].
 *
 *  The bounds are rearranged (max_value - y, min_value - y) so that the sum
 *  x + y itself is never evaluated.
 */
template <>
inline bool WillOverflow<tir::AddNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
  if (y > 0) {
    // A positive addend can only push the result past the upper bound.
    return x > max_value - y;
  }
  if (y < 0) {
    // A negative addend can only push the result past the lower bound.
    return x < min_value - y;
  }
  // Adding zero never flags an overflow.
  return false;
}

/*!
 * \brief Overflow check for x - y against the domain [min_value, max_value].
 *
 *  The bounds are rearranged (min_value + y, max_value + y) so that the
 *  difference x - y itself is never evaluated.
 */
template <>
inline bool WillOverflow<tir::SubNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
  // Subtracting a positive value can only undershoot the lower bound;
  // subtracting a negative value can only overshoot the upper bound.
  return (y > 0) ? (x < min_value + y) : ((y < 0) && (x > max_value + y));
}

/*!
 * \brief Overflow check for x * y against the domain [min_value, max_value].
 *
 *  The product is never evaluated; instead x is compared against the bounds
 *  divided by y (truncating division), with the comparison direction flipped
 *  for negative y. The case y == -1 is handled without any division, because
 *  dividing a bound equal to the minimum int64_t value by -1 is undefined
 *  behaviour.
 */
template <>
inline bool WillOverflow<tir::MulNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
  // Multiplying by zero never flags an overflow.
  if (y == 0) return false;
  if (y > 0) {
    if (x < min_value / y) return true;
    if (x > max_value / y) return true;
    return false;
  }
  if (y == -1) {
    // x * -1 == -x. The negation is not representable for the minimum
    // int64_t value; otherwise compare -x with the bounds directly instead
    // of evaluating min_value / -1 or max_value / -1.
    if (x == std::numeric_limits<int64_t>::min()) return true;
    const int64_t negated = -x;
    return negated < min_value || negated > max_value;
  }
  // y < -1: dividing by a negative value reverses the inequalities.
  if (x > min_value / y) return true;
  if (x < max_value / y) return true;
  return false;
}

/*!
 * \brief Check for x % y: the only condition flagged is a zero divisor.
 *
 *  x, min_value and max_value are not inspected.
 */
template <>
inline bool WillOverflow<tir::ModNode>(int64_t x, int64_t y, int64_t min_value, int64_t max_value) {
  const bool divisor_is_zero = (y == 0);
  return divisor_is_zero;
}

/*!
 * \brief Perform truncated division of two integers (quotient rounds toward zero).
 * \param x The dividend.
 * \param y The divisor; the built-in division is used, so it must be non-zero.
 * \return The truncated quotient x / y.
 */
inline int64_t truncdiv(int64_t x, int64_t y) {
  const int64_t quotient = x / y;
  return quotient;
}

/*!
 * \brief Compute the remainder that pairs with truncated division.
 * \param x The dividend.
 * \param y The divisor; the built-in remainder is used, so it must be non-zero.
 * \return The remainder x % y, which is zero or carries the sign of x.
 */
inline int64_t truncmod(int64_t x, int64_t y) {
  const int64_t remainder = x % y;
  return remainder;
}

/*!
 * \brief Perform floor division of two integers (quotient rounds toward negative infinity).
 * \param x The dividend.
 * \param y The divisor; the built-in division is used, so it must be non-zero.
 * \return The floored quotient.
 */
inline int64_t floordiv(int64_t x, int64_t y) {
  const int64_t quotient = x / y;
  const int64_t remainder = x % y;
  // The truncated quotient is one too large exactly when the remainder is
  // non-zero and its sign disagrees with the sign of the divisor.
  const bool needs_adjust = (y >= 0) ? (remainder < 0) : (remainder > 0);
  if (needs_adjust) {
    return quotient - 1;
  }
  return quotient;
}

/*!
 * \brief Compute the remainder that pairs with floor division.
 * \param x The dividend.
 * \param y The divisor; the built-in remainder is used, so it must be non-zero.
 * \return The remainder, which is zero or carries the sign of y.
 */
inline int64_t floormod(int64_t x, int64_t y) {
  int64_t result = x % y;
  // The truncated remainder follows the sign of x; shift it by one divisor
  // when it is non-zero and its sign disagrees with the sign of y.
  const bool sign_mismatch = (y >= 0) ? (result < 0) : (result > 0);
  if (sign_mismatch) {
    result += y;
  }
  return result;
}

/*!
 * \brief Use the Extended Euclidean algorithm to solve a * x + b * y = g,
 *        where g is the value returned.
 *
 *  When b == 0, y is set to 1. The sign of the returned value follows the
 *  remainder sequence, so it can be negative when b is negative
 *  (for example a = 4, b = -6 yields -2 with x = 1, y = 1).
 *
 * \param a The first coefficient.
 * \param b The second coefficient.
 * \param x Output: the solution for x.
 * \param y Output: the solution for y.
 * \return The last non-zero remainder, i.e. the GCD of a and b up to sign.
 */
inline int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t* x, int64_t* y) {
  // A negative a is handled by running the algorithm on |a| and flipping the
  // sign of its coefficient afterwards: |a| * (-x) + b * y = g.
  const bool a_is_negative = a < 0;

  // Remainder sequence and the matching coefficients of |a|. Start from
  //   |a| * 1 + b * 0 = |a|   (previous row)
  //   |a| * 0 + b * 1 = b     (current row)
  int64_t prev_rem = a_is_negative ? -a : a;
  int64_t cur_rem = b;
  int64_t prev_coef = 1;
  int64_t cur_coef = 0;

  // Each step subtracts quotient times the current row from the previous one.
  // The remainders shrink in magnitude, so the loop terminates.
  while (cur_rem != 0) {
    const int64_t quotient = prev_rem / cur_rem;

    const int64_t next_rem = prev_rem - quotient * cur_rem;
    prev_rem = cur_rem;
    cur_rem = next_rem;

    const int64_t next_coef = prev_coef - quotient * cur_coef;
    prev_coef = cur_coef;
    cur_coef = next_coef;
  }

  const int64_t coef_a = a_is_negative ? -prev_coef : prev_coef;
  *x = coef_a;
  // Recover the coefficient of b from the identity; with b == 0 any value
  // satisfies it, and 1 is used.
  *y = (b == 0) ? 1 : (prev_rem - coef_a * a) / b;

  return prev_rem;
}

/*!
 * \brief Take the GCD of a and b, where either operand may be zero.
 *
 *  Negative operands are replaced by their negation first. If one operand
 *  is zero the other one (after negation) is returned, so the result for
 *  (0, 0) is 0.
 *
 * \param a The first operand.
 * \param b The second operand.
 * \return The result.
 */
inline int64_t ZeroAwareGCD(int64_t a, int64_t b) {
  int64_t larger = (a < 0) ? -a : a;
  int64_t smaller = (b < 0) ? -b : b;
  if (larger < smaller) std::swap(larger, smaller);
  // gcd(v, 0) is taken to be v.
  if (smaller == 0) return larger;
  // Euclid's algorithm: replace (larger, smaller) by (smaller, larger % smaller)
  // until the remainder vanishes; the last divisor is the GCD.
  for (int64_t rem = larger % smaller; rem != 0; rem = larger % smaller) {
    larger = smaller;
    smaller = rem;
  }
  return smaller;
}

/*!
 * \brief Calculate the least common multiple for two values.
 *
 *  Computed as (a / g) * b, where g is the value returned by
 *  ExtendedEuclidean(a, b). Dividing before multiplying avoids overflowing
 *  on the intermediate product a * b when the result itself is representable.
 *
 * \param a an integer number
 * \param b an integer number
 * \return the least common multiple; 0 when both a and b are 0.
 */
inline int64_t LeastCommonMultiple(int64_t a, int64_t b) {
  // Only the returned divisor is needed; the Bezout coefficients are discarded.
  int64_t x, y;
  const int64_t gcd = ExtendedEuclidean(a, b, &x, &y);
  // The divisor is 0 only when a == b == 0; avoid dividing by zero there.
  if (gcd == 0) return 0;
  // gcd divides a exactly, so this equals (a * b) / gcd.
  return (a / gcd) * b;
}

}  // namespace arith
}  // namespace tvm
#endif  // TVM_ARITH_INT_OPERATOR_H_