# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import random
import re
import threading

import numpy as np
import pytest

import tvm
import tvm.testing
from tvm import relay, te
from tvm.topi.math import cast


def randint_loguniform(low=1, high=32768, size=None):
    logN = np.random.uniform(low=np.log(low), high=np.log(high), size=size)
    N = np.exp(logN).astype(int)
    return np.unique(N)


dtype = tvm.testing.parameter("float32", "int32", "float16", "int8")
fuzz_arr_size = tvm.testing.parameter(*randint_loguniform(size=25))


# Explicitly specify a target, as this test is looking at the
# generated shader code, and is not running on an actual device.
@tvm.testing.parametrize_targets(
    " ".join(
        [
            "vulkan",
            "-supports_int8=1",
            "-supports_8bit_buffer=1",
            "-supports_storage_buffer_storage_class=1",
            "-supports_float16=1",
            "-supports_16bit_buffer=1",
        ]
    )
)
def test_vector_comparison(target, dtype):
    n = (1024,)
    A = te.placeholder(n, dtype=dtype, name="A")
    B = te.compute(
        A.shape,
        lambda i: tvm.tir.Select(
            A[i] >= 0, A[i] + tvm.tir.const(1, dtype), tvm.tir.const(0, dtype)
        ),
        name="B",
    )
    s = te.create_schedule(B.op)

    (bx, tx) = s[B].split(s[B].op.axis[0], factor=128)
    (tx, vx) = s[B].split(tx, factor=4)
    s[B].bind(bx, te.thread_axis("blockIdx.x"))
    s[B].bind(tx, te.thread_axis("threadIdx.x"))
    s[B].vectorize(vx)
    f = tvm.build(s, [A, B], target)

    # Verify we generate the boolx4 type declaration and the OpSelect
    # v4{float,half,int} instruction
    assembly = f.imported_modules[0].get_source()
    matches = re.findall("%v4bool = OpTypeVector %bool 4", assembly)
    assert len(matches) == 1
    matches = re.findall("OpSelect %v4.*", assembly)
    assert len(matches) == 1


def test_array_copy(dev, dtype, fuzz_arr_size):
    a_np = np.random.uniform(size=(fuzz_arr_size,)).astype(dtype)
    a = tvm.nd.empty((fuzz_arr_size,), dtype, dev).copyfrom(a_np)
    b_np = a.numpy()
    tvm.testing.assert_allclose(a_np, b_np)
    tvm.testing.assert_allclose(a_np, a.numpy())


@tvm.testing.exclude_targets("llvm")
def test_array_vectorize_add(target, dev, dtype):
    arr_size = 64
    lanes = 2

    num_thread = 8

    A = te.placeholder((arr_size,), name="A", dtype="%sx%d" % (dtype, lanes))
    B = te.compute((arr_size,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
    s = te.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
    s[B].bind(xo, te.thread_axis("blockIdx.x"))
    s[B].bind(xi, te.thread_axis("threadIdx.x"))
    fun = tvm.build(s, [A, B], target)
    a = tvm.nd.empty((arr_size,), A.dtype, dev).copyfrom(np.random.uniform(size=(arr_size, lanes)))
    c = tvm.nd.empty((arr_size,), B.dtype, dev)
    fun(a, c)
    tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1)


@tvm.testing.parametrize_targets("vulkan")
def test_vulkan_stress(target, dev):
    """
    Launch a randomized test with multiple kernels per stream, multiple uses of
    kernels per stream, over multiple threads.
    """

    n = 1024
    num_thread = 64

    def run_stress():
        def worker():
            A = te.placeholder((n,), name="A", dtype="float32")
            B = te.placeholder((n,), name="B", dtype="float32")
            functions = [
                (
                    lambda: te.compute((n,), lambda i: 2 * A[i] + 3 * B[i]),
                    lambda a, b: 2 * a + 3 * b,
                ),
                (lambda: te.compute((n,), lambda i: A[i] + B[i]), lambda a, b: a + b),
                (lambda: te.compute((n,), lambda i: A[i] + 2 * B[i]), lambda a, b: a + 2 * b),
            ]

            def build_f(f_ref):
                (C_f, ref) = f_ref
                C = C_f()
                s = te.create_schedule(C.op)
                xo, xi = s[C].split(C.op.axis[0], factor=num_thread)
                s[C].bind(xo, te.thread_axis("blockIdx.x"))
                s[C].bind(xi, te.thread_axis("threadIdx.x"))
                fun = tvm.build(s, [A, B, C], target)
                return (fun, ref)

            fs = [
                build_f(random.choice(functions)) for _ in range(np.random.randint(low=1, high=10))
            ]
            a = tvm.nd.empty((n,), A.dtype, dev).copyfrom(np.random.uniform(size=(n,)))
            b = tvm.nd.empty((n,), B.dtype, dev).copyfrom(np.random.uniform(size=(n,)))
            cs = [tvm.nd.empty((n,), A.dtype, dev) for _ in fs]
            for ((f, _), c) in zip(fs, cs):
                f(a, b, c)

            for ((_, ref), c) in zip(fs, cs):
                tvm.testing.assert_allclose(c.numpy(), ref(a.numpy(), b.numpy()))

        ts = [threading.Thread(target=worker) for _ in range(np.random.randint(1, 10))]
        for t in ts:
            t.start()
        for t in ts:
            t.join()

    run_stress()


@tvm.testing.exclude_targets("llvm")
def test_vulkan_bool_load(target, dev):
    arr_size = 1024

    target = tvm.target.Target(target)
    if target.kind.name == "vulkan":
        supports_int8_buffer = target.attrs.get("supports_int8", False) and target.attrs.get(
            "supports_8bit_buffer", False
        )
        if not supports_int8_buffer:
            pytest.xfail(
                "Vulkan target does not support int8 buffer access, used to transfer booleans"
            )

    def do_copy(A, B, n):
        ib = tvm.tir.ir_builder.create()
        A = ib.buffer_ptr(A)
        B = ib.buffer_ptr(B)

        tx = te.thread_axis("threadIdx.x")
        bx = te.thread_axis("blockIdx.x")

        max_threads = 32
        ib.scope_attr(bx, "thread_extent", tvm.tir.indexdiv(n + max_threads - 1, max_threads))
        ib.scope_attr(tx, "thread_extent", max_threads)
        tid = bx * max_threads + tx

        with ib.if_scope(tid < n):
            B[tid] = cast(A[tid], "int32")

        return ib.get()

    A = te.placeholder((arr_size,), name="A", dtype="bool")
    B = te.placeholder((arr_size,), name="B", dtype="int32")

    B = te.extern(
        A.shape,
        [A],
        lambda ins, outs: do_copy(ins[0], outs[0], arr_size),
        name="bool_copy_ir",
        dtype="int32",
    )
    s = te.create_schedule(B.op)

    with tvm.transform.PassContext(opt_level=3):
        func = tvm.build(s, [A, B], target)

    a_np = np.random.uniform(size=arr_size) > 0.5
    b_np = np.zeros((arr_size,), dtype="int32")
    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.array(b_np, dev)
    func(a, b)
    ref = a_np.astype(np.int32)
    tvm.testing.assert_allclose(b.numpy(), ref)


def check_mod(target, dev, mod, x_np, res_np):
    res = relay.create_executor("vm", mod=mod, device=dev, target=target).evaluate()(x_np).numpy()
    tvm.testing.assert_allclose(res, res_np, atol=1e-5)


def test_sqrt(target, dev):
    # Three 32 bit pushconstants: any_dim, stride, stride
    dtype = "float32"
    x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
    mod = tvm.IRModule()
    mod["main"] = relay.Function([x], relay.sqrt(x))
    x_np = np.random.uniform(size=(10,)).astype(dtype)
    res_np = np.sqrt(x_np)

    check_mod(target, dev, mod, x_np, res_np)


def test_argsort(target, dev):
    # One 64 bit and one 32 bit constants
    dtype = "int32"
    x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
    mod = tvm.IRModule()
    mod["main"] = relay.Function([x], relay.argsort(x))
    x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
    res_np = np.argsort(x_np)

    check_mod(target, dev, mod, x_np, res_np)


def test_cumsum(target, dev):
    # One 64 bit and one 32 bit constants
    dtype = "int32"
    x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
    mod = tvm.IRModule()
    mod["main"] = relay.Function([x], relay.cumsum(x))
    x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
    res_np = np.cumsum(x_np)

    check_mod(target, dev, mod, x_np, res_np)


def test_unique(target, dev):
    dtype = "int32"
    x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
    mod = tvm.IRModule()
    [unique, _, _, num_unique] = relay.unique(x, is_sorted=True)
    mod["main"] = relay.Function([x], relay.op.strided_slice(unique, begin=[0], end=num_unique))
    x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
    res_np = np.unique(x_np)
    check_mod(target, dev, mod, x_np, res_np)


vulkan_parameter_impl = tvm.testing.parameter("push_constants", "ubo")
vulkan_parameter_dtype = tvm.testing.parameter("int32", "float32", "int64")

# Only run on vulkan because extremely large numbers of input
# parameters can crash cuda/llvm compiler.
@tvm.testing.parametrize_targets("vulkan -from_device=0")
def test_vulkan_constant_passing(target, dev, vulkan_parameter_impl, vulkan_parameter_dtype):
    target = tvm.target.Target(target)
    dtype = vulkan_parameter_dtype

    if not target.attrs.get("supports_int64", False):
        pytest.xfail("Vulkan target does not support Int64 variables")

    # f_add has 3+num_int_params scalar parameters.  The other three
    # are length_n, stride1, and stride2.
    if vulkan_parameter_impl == "push_constants":
        # 4 params, 32 bytes.  Within 128-byte spec-guaranteed size of
        # push constants.  Uses push constants.
        num_int_params = 1
    else:
        # 24 params, 192 bytes.  May be above spec-guaranteed size of 128
        # bytes for push constants.  Uses either push constants or UBO,
        # depending on the device.
        max_push_constants_size = int(target.attrs.get("max_push_constants_size", 128))
        max_int_params_in_push = max_push_constants_size // 8 - 3
        num_int_params = max_int_params_in_push + 1

    n = te.var("n")
    scalars = [te.var("scale{}".format(i), dtype=dtype) for i in range(num_int_params)]
    scalar_sum = scalars[0]
    for s in scalars[1:]:
        scalar_sum += s

    A = te.placeholder((n,), name="A", dtype=dtype)
    B = te.compute(A.shape, lambda i: scalar_sum + A[i], name="B")

    s = te.create_schedule(B.op)
    xo, xi = s[B].split(B.op.axis[0], factor=64)
    s[B].bind(xo, te.thread_axis("blockIdx.x"))
    s[B].bind(xi, te.thread_axis("threadIdx.x"))
    f_add = tvm.build(s, scalars + [A, B], target)

    n = 1024
    scalars = np.array([1 for _ in scalars]).astype(dtype)
    a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
    b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)
    f_add(*scalars, a, b)

    tvm.testing.assert_allclose(a.numpy() + sum(scalars), b.numpy())


def test_vulkan_while_if(target, dev):
    target = tvm.target.Target(target)

    def do_compute(A, B, n):
        ib = tvm.tir.ir_builder.create()
        A = ib.buffer_ptr(A)
        B = ib.buffer_ptr(B)

        if "gpu" in target.keys:
            ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0)

        iterations = ib.allocate("int32", (1,), name="iterations", scope="local")
        iterations[0] = 0
        B[0] = 0

        # WhileNode's condition is re-evaluated every loop.  The
        # if_then_else block introduces additional labels/blocks that
        # must be kept separate from the WhileNode's block.
        loop_condition = iterations[0] < tvm.tir.if_then_else(A[0] > 0, 10, 20)
        with ib.while_loop(loop_condition):
            iterations[0] += 1
            B[0] += iterations[0]

        return ib.get()

    n = 1
    dtype = "int32"
    A = te.placeholder((n,), name="A", dtype=dtype)

    B = te.extern(
        A.shape,
        [A],
        lambda ins, outs: do_compute(ins[0], outs[0], n),
        dtype=dtype,
    )
    s = te.create_schedule(B.op)

    # Point of failure would be here, at tvm.build.
    with tvm.transform.PassContext(opt_level=3):
        func = tvm.build(s, [A, B], target)

    a = tvm.nd.array(np.array([5], dtype=A.dtype), dev)
    b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev)
    func(a, b)
    tvm.testing.assert_allclose(b.numpy(), [55])

    a = tvm.nd.array(np.array([-5], dtype=A.dtype), dev)
    b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev)
    func(a, b)
    tvm.testing.assert_allclose(b.numpy(), [210])


@tvm.testing.exclude_targets("llvm")
def test_vulkan_local_threadidx(target, dev):
    # To access the thread index, the vulkan runtime accesses a global
    # array of thread indices, storing the result in a local variable.
    # In CUDA, these are the built-in threadIdx.x variables, which are
    # globally accessible.  In vulkan, these local variables must be
    # defined inside a function, but are hoisted up to the function
    # header to mimic the global CUDA semantics.  Before this
    # hoisting, this test could trigger spvValidate errors for
    # potentially undeclared variables.

    def do_compute(A, B, n):
        ib = tvm.tir.ir_builder.create()
        A = ib.buffer_ptr(A)
        B = ib.buffer_ptr(B)

        # One single declaration of te.thread_axis.
        tx = te.thread_axis("threadIdx.x")

        with ib.for_range(0, 1):
            # Used inside a for-loop scope, defines local thread_id
            # variable.
            ib.scope_attr(tx, "thread_extent", 16)
            B[tx + 0] = A[tx + 0]

        with ib.for_range(0, 1):
            # Used in next scope.  If local variable defined at point
            # of use instead of function header, will fail spvValidate
            # for access of out-of-scope local variable.
            ib.scope_attr(tx, "thread_extent", 16)
            B[tx + 16] = A[tx + 16]

        return ib.get()

    n = te.var("n")
    A = te.placeholder((n,), name="A", dtype="int32")
    B = te.placeholder((n,), name="B", dtype="int32")

    B = te.extern(
        A.shape,
        [A],
        lambda ins, outs: do_compute(ins[0], outs[0], n),
        dtype="int32",
    )
    s = te.create_schedule(B.op)

    # Expected failure occurs at build step.
    func = tvm.build(s, [A, B], target)

    n = 32
    a_np = np.arange(n).astype(dtype=A.dtype)
    b_np = np.zeros((n,), dtype="int32")
    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.array(b_np, dev)
    func(a, b)
    tvm.testing.assert_allclose(b.numpy(), a_np)


class TestVectorizedIndices:
    load_type, store_type = tvm.testing.parameters(
        # Load N values, write to N locations.
        # Vectorized copy.
        ("ramp", "ramp"),
        # Load 1 value, write to N locations.
        # Scalar load, vectorized store.
        #
        # Most TVM operations (e.g. schedule[tensor].vectorize(axis)) have
        # the broadcast outside of the index, but it is semantically okay
        # for the broadcast to be inside the index, and it shows up with
        # some optimizations.
        ("broadcast", "ramp"),
        # Load 1 values, write to 1 location.
        # Broadcasting on both sides should be equivalent to a scalar copy.
        ("broadcast", "broadcast"),
        # Loads N values, write to 1 location.
        # Disabled as it would have unclear semantics.
        # ("ramp","broadcoast"),
    )
    indirect_indices = tvm.testing.parameter(True, False, ids=["reorder", "no_reorder"])

    @tvm.testing.fixture
    def ref_data(self, load_type, store_type, indirect_indices):
        n = 4

        index_map = {
            "ramp": np.arange(n),
            "broadcast": np.zeros(n, dtype="int32"),
        }

        a_np = np.random.randint(np.iinfo("int32").max, size=n).astype("int32")
        b_np = np.zeros(shape=n, dtype=a_np.dtype)
        reorder_np = np.arange(n, dtype="int32")[::-1]

        load_index = index_map[load_type]
        store_index = index_map[store_type]

        if indirect_indices:
            load_index = reorder_np[load_index]

        b_np[store_index] = a_np[load_index]

        return a_np, reorder_np, b_np

    @tvm.testing.fixture
    def mod(self, target, load_type, store_type, indirect_indices):
        target = tvm.target.Target(target)

        n = 4
        dtype = "int32"
        A = te.placeholder((n,), dtype=dtype, name="A")
        R = te.placeholder((n,), dtype=dtype, name="R")

        def do_compute(ins, outs):
            ib = tvm.tir.ir_builder.create()
            A, R = map(ib.buffer_ptr, ins)
            B = ib.buffer_ptr(outs[0])

            if "gpu" in target.keys:
                ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0)

            index_map = {
                "ramp": tvm.tir.Ramp(0, 1, 4),
                "broadcast": tvm.tir.Broadcast(0, 4),
            }

            load_index = index_map[load_type]
            store_index = index_map[store_type]

            if indirect_indices:
                load_index = tvm.tir.expr.Load("int32x4", R, load_index)

            transfer = tvm.tir.expr.Load("int32x4", A, load_index)
            ib.emit(tvm.tir.stmt.Store(B, transfer, store_index))

            return ib.get()

        B = te.extern(A.shape, [A, R], do_compute, dtype="int32")
        s = te.create_schedule(B.op)

        return tvm.lower(s, [A, R, B])

    def test_ramp_broadcast_index(self, target, dev, mod, ref_data):
        f = tvm.build(mod, target=target)

        a_np, reorder_np, b_np = ref_data
        a = tvm.nd.array(a_np, dev)
        r = tvm.nd.array(reorder_np, dev)
        b = tvm.nd.array(np.zeros(shape=b_np.shape, dtype="int32"), dev)
        f(a, r, b)
        tvm.testing.assert_allclose(b.numpy(), b_np)


@tvm.testing.parametrize_targets("vulkan -max_shared_memory_per_block=16384")
def test_shared_mem_alloc(target, dev):
    alloc_nbytes = 16384 * 2

    def do_compute(ins, outs):
        ib = tvm.tir.ir_builder.create()
        out = ib.buffer_ptr(outs[0])

        ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0)

        array = ib.allocate("int32", (alloc_nbytes,), name="array", scope="shared")
        array[0] = 0
        out[0] = array[0]

        return ib.get()

    Out = te.extern(
        shape=(1,),
        inputs=[],
        fcompute=do_compute,
        dtype="int32",
    )
    s = te.create_schedule(Out.op)

    # Codegen should raise error when allocating more memory than the
    # target supports.
    with pytest.raises(tvm.TVMError):
        tvm.build(s, [Out], target)


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main([__file__] + sys.argv[1:]))