# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,too-many-locals,too-many-arguments,missing-module-docstring

import tvm
from tvm import relay
from tvm.relay import transform


def run_opt_pass(expr, opt_pass):
    "runs the opt_pass on the expr of a function the function"
    assert isinstance(opt_pass, tvm.transform.Pass)
    mod = tvm.IRModule.from_expr(expr)
    mod = tvm.relay.transform.InferType()(mod)
    mod = opt_pass(mod)
    return mod["main"]


def test_combine_parallel_batch_matmul():
    """Simple testcase."""

    def before(x, w1, w2, w3):
        args = [x, w1, w2, w3]
        y1 = relay.nn.batch_matmul(x, w1)
        y2 = relay.nn.batch_matmul(x, w2)
        y3 = relay.nn.batch_matmul(x, w3)
        y = relay.Tuple((y1, y2, y3))
        return relay.Function(args, y)

    def expected(x, w1, w2, w3):
        # use a fixed order of args so the structural equality check can pass
        s1 = w1.type_annotation.shape[1]
        s2 = w2.type_annotation.shape[1]
        s3 = w3.type_annotation.shape[1]
        args = [x, w1, w2, w3]
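        # stack the three weights along the output-feature axis (j)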
        w = relay.concatenate((w1, w2, w3), axis=1)
        y = relay.nn.batch_matmul(x, w)
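        # with slice_mode="size", `begin` is the start offset and `end` is the
        # slice length per axis; -1 means "keep the full extent of this axis"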
        y1 = relay.strided_slice(
            y, begin=[0, 0, 0], end=[-1, -1, s1], strides=[1, 1, 1], slice_mode="size"
        )
        y2 = relay.strided_slice(
            y, begin=[0, 0, s1], end=[-1, -1, s2], strides=[1, 1, 1], slice_mode="size"
        )
        y3 = relay.strided_slice(
            y, begin=[0, 0, s1 + s2], end=[-1, -1, s3], strides=[1, 1, 1], slice_mode="size"
        )
        y = relay.Tuple((y1, y2, y3))
        return relay.Function(args, y)

    def check(b, i, j, k):
        x = relay.var("x", shape=(b, i, k))
        w1 = relay.var("w1", shape=(b, j, k))
        w2 = relay.var("w2", shape=(b, j, k))
        w3 = relay.var("w3", shape=(b, j, k))
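        # relay.nn.batch_matmul contracts over the last axis of both operands:
        # (b, i, k) x (b, j, k) -> (b, i, j)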

        y_before = before(x, w1, w2, w3)
        y = run_opt_pass(y_before, transform.CombineParallelBatchMatmul(min_num_branches=2))
        y_expected = expected(x, w1, w2, w3)
        y_expected = run_opt_pass(y_expected, transform.InferType())
        tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)

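    # exercise a small shape and a larger one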
    check(2, 3, 5, 4)
    check(1, 100, 200, 300)


def test_combine_parallel_batch_matmul_biasadd():
    """Simple testcase with bias"""

    def before(x, w1, w2, w3, b1, b2, b3):
        args = [x, w1, w2, w3, b1, b2, b3]
        y1 = relay.nn.batch_matmul(x, w1)
        y2 = relay.nn.batch_matmul(x, w2)
        y3 = relay.nn.batch_matmul(x, w3)
        y1 = relay.add(y1, b1)
        y2 = relay.add(y2, b2)
        y3 = relay.add(y3, b3)
        y = relay.Tuple((y1, y2, y3))
        return relay.Function(args, y)

    def expected(x, w1, w2, w3, b1, b2, b3):
        # use a fixed order of args so the structural equality check can pass
        s1 = w1.type_annotation.shape[1]
        s2 = w2.type_annotation.shape[1]
        s3 = w3.type_annotation.shape[1]
        args = [x, w1, w2, w3, b1, b2, b3]
        w = relay.concatenate((w1, w2, w3), axis=1)
        b = relay.concatenate((b1, b2, b3), axis=-1)
        y = relay.nn.batch_matmul(x, w)
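        # the concatenated bias, shape (3 * j,), broadcasts across the
        # (b, i, 3 * j) output of the combined batch_matmul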
        y = relay.add(y, b)
        y1 = relay.strided_slice(
            y, begin=[0, 0, 0], end=[-1, -1, s1], strides=[1, 1, 1], slice_mode="size"
        )
        y2 = relay.strided_slice(
            y, begin=[0, 0, s1], end=[-1, -1, s2], strides=[1, 1, 1], slice_mode="size"
        )
        y3 = relay.strided_slice(
            y, begin=[0, 0, s1 + s2], end=[-1, -1, s3], strides=[1, 1, 1], slice_mode="size"
        )
        y = relay.Tuple((y1, y2, y3))
        return relay.Function(args, y)

    def check(b, i, j, k):
        x = relay.var("x", shape=(b, i, k))
        w1 = relay.var("w1", shape=(b, j, k))
        w2 = relay.var("w2", shape=(b, j, k))
        w3 = relay.var("w3", shape=(b, j, k))
        b1 = relay.var("b1", shape=(j,))
        b2 = relay.var("b2", shape=(j,))
        b3 = relay.var("b3", shape=(j,))

        y_before = before(x, w1, w2, w3, b1, b2, b3)
        y = run_opt_pass(y_before, transform.CombineParallelBatchMatmul(min_num_branches=2))
        y_expected = expected(x, w1, w2, w3, b1, b2, b3)
        y_expected = run_opt_pass(y_expected, transform.InferType())
        tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)

    check(2, 3, 5, 4)
    check(1, 100, 200, 300)


if __name__ == "__main__":
    test_combine_parallel_batch_matmul()
    test_combine_parallel_batch_matmul_biasadd()