# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import tensorflow as tf
from MaskRCNN.performance import print_runtime_tensor

from tensorpack.models import Conv2D, Conv2DTranspose, layer_register
from tensorpack.tfutils.argscope import argscope
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope, under_name_scope
from tensorpack.tfutils.summary import add_moving_summary, add_tensor_summary

from model.backbone import GroupNorm
from config import config as cfg
#from utils.mixed_precision import mixed_precision_scope

@under_name_scope()
def maskrcnn_loss(mask_logits, fg_labels, fg_target_masks):
    """
    Args:
        mask_logits: Num_fg_boxes x  num_category x H_roi x W_roi 
        fg_labels: 1-D Num_fg_boxes, in 1~#class, int64
        fg_target_masks: Num_fg_boxes x H_roi x W_roi, float32
    Returns: mask loss
    """
    num_fg = tf.size(fg_labels, out_type=tf.int64) # scalar Num_fg_boxes
    indices = tf.stack([tf.range(num_fg), fg_labels - 1], axis=1)  # Num_fg_boxes x 2
    mask_logits = tf.gather_nd(mask_logits, indices)  # Num_fg_boxes x H_roi x W_roi
    mask_probs = tf.sigmoid(mask_logits)

    # add some training visualizations to tensorboard
    with tf.name_scope('mask_viz'):
        viz = tf.concat([fg_target_masks, mask_probs], axis=1)
        viz = tf.expand_dims(viz, 3)
        viz = tf.cast(viz * 255, tf.uint8, name='viz')
        tf.summary.image('mask_truth|pred', viz, max_outputs=10)

    loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=fg_target_masks, logits=mask_logits)
    loss = tf.math.reduce_mean(loss, name='maskrcnn_loss')

    # Calculate the accuracy
    pred_label = mask_probs > 0.5
    truth_label = fg_target_masks > 0.5
    accuracy = tf.math.reduce_mean(tf.cast(tf.equal(pred_label, truth_label), tf.float32), name='accuracy')
    pos_accuracy = tf.math.logical_and(tf.equal(pred_label, truth_label), tf.equal(truth_label, True))
    pos_accuracy = tf.math.reduce_mean(tf.cast(pos_accuracy, tf.float32), name='pos_accuracy')
    fg_pixel_ratio = tf.math.reduce_mean(tf.cast(truth_label, tf.float32), name='fg_pixel_ratio')

    add_moving_summary(loss, accuracy, fg_pixel_ratio, pos_accuracy)
    return loss


@layer_register(log_shape=True)
@auto_reuse_variable_scope
def maskrcnn_upXconv_head(feature, num_category, seed_gen, num_convs, norm=None):
    """
    Args:
        feature: roi feature maps, Num_boxes x  H_roi x W_roi x NumChannel
        num_category(int): Number of total classes
        num_convs (int): number of convolution layers
        norm (str or None): either None or 'GN'

    Returns:
        mask_logits: Num_boxes x num_category X (2 * H_roi) x (2 * W_roi) 
    """
    assert norm in [None, 'GN'], norm
    l = feature
    with argscope([Conv2D, Conv2DTranspose], data_format='channels_first' if cfg.TRAIN.MASK_NCHW else 'channels_last',
                  kernel_initializer=tf.keras.initializers.VarianceScaling(
                      scale=2.0, mode='fan_out', 
                      distribution='untruncated_normal',
                      seed=seed_gen.next())):
        # c2's MSRAFill is fan_out
        for k in range(num_convs):
            l = Conv2D('fcn{}'.format(k), l, cfg.MRCNN.HEAD_DIM, 3, activation=tf.nn.relu, seed=seed_gen.next())
            if norm is not None:
                l = GroupNorm('gn{}'.format(k), l)
        l = Conv2DTranspose('deconv', l, cfg.MRCNN.HEAD_DIM, 2, strides=2, activation=tf.nn.relu, seed=seed_gen.next()) # 2x upsampling
        l = Conv2D('conv', l, num_category, 1, seed=seed_gen.next())
    if not cfg.TRAIN.MASK_NCHW:
        l = tf.transpose(l, [0, 3, 1, 2])
    return l

# Without Group Norm
def maskrcnn_up4conv_head(*args, **kwargs):
    return maskrcnn_upXconv_head(*args, num_convs=4, **kwargs)

# With Group Norm
def maskrcnn_up4conv_gn_head(*args, **kwargs):
    return maskrcnn_upXconv_head(*args, num_convs=4, norm='GN', **kwargs)