#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (c) 2021, Amazon Web Services. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from sagemakercv.core import training_ops, GenericRoIExtractor, TargetEncoder, BoxDetector, RandomSampler
from ..builder import HEADS, build_box_head, build_mask_head

class StandardRoIHead(tf.keras.Model):
    """Simplest base roi head including one bbox head and one mask head."""
    
    def __init__(self,
                 bbox_head,
                 bbox_roi_extractor,
                 bbox_sampler,
                 box_encoder,
                 inference_detector,
                 mask_head=None,
                 mask_roi_extractor=None,
                 name="StandardRoIHead",
                 trainable=True,
                 *args,
                 **kwargs,
                 ):
        super(StandardRoIHead, self).__init__(name=name, trainable=trainable, *args, **kwargs)
        self.bbox_head = bbox_head
        self.bbox_roi_extractor = bbox_roi_extractor
        self.bbox_sampler = bbox_sampler
        self.box_encoder = box_encoder
        self.mask_head = mask_head
        self.mask_roi_extractor = mask_roi_extractor if mask_roi_extractor is not None \
                                  else self.bbox_roi_extractor
        self.inference_detector = inference_detector
    
    @property
    def with_mask(self):
        """bool: whether the RoI head contains a `mask_head`"""
        return hasattr(self, 'mask_head') and self.mask_head is not None
    
    def call(self,
             fpn_feats,
             img_info,
             proposals,
             gt_bboxes=None,
             gt_labels=None,
             gt_masks=None,
             training=True):
        model_outputs=dict()
        
        if training:
            box_targets, class_targets, rpn_box_rois, proposal_to_label_map = self.bbox_sampler(proposals,
                                                                                                gt_bboxes, 
                                                                                                gt_labels)
        else:
            rpn_box_rois = proposals
            
        box_roi_features = self.bbox_roi_extractor(fpn_feats, rpn_box_rois)
        
        class_outputs, box_outputs, _ = self.bbox_head(inputs=box_roi_features)
        
        if not training:
            model_outputs.update(self.inference_detector(class_outputs, box_outputs, rpn_box_rois, img_info))
            model_outputs.update({'class_outputs': tf.nn.softmax(class_outputs),
                                  'box_outputs': box_outputs,
                                  'anchor_boxes': rpn_box_rois})
        else:
            if self.bbox_head.loss.box_loss_type not in ["giou", "ciou"]:
                encoded_box_targets = self.box_encoder(boxes=rpn_box_rois,
                                                       gt_bboxes=box_targets,
                                                       gt_labels=class_targets)
            model_outputs.update({
                'class_outputs': class_outputs,
                'box_outputs': box_outputs,
                'class_targets': class_targets,
                'box_targets': encoded_box_targets if self.bbox_head.loss.box_loss_type \
                                                      not in ["giou", "ciou"] \
                                                      else box_targets,
                'box_rois': rpn_box_rois,
            })
            total_loss, class_loss, box_loss = self.bbox_head.loss(model_outputs['class_outputs'],
                                                                   model_outputs['box_outputs'],
                                                                   model_outputs['class_targets'],
                                                                   model_outputs['box_targets'],
                                                                   model_outputs['box_rois'],
                                                                   img_info)
            model_outputs.update({
                'total_loss_bbox': total_loss,
                'class_loss': class_loss,
                'box_loss': box_loss
            })
        if not self.with_mask:
            return model_outputs
        if not training:
            return self.call_mask(model_outputs, fpn_feats, training=False)
        max_fg = int(self.bbox_sampler.batch_size_per_im * self.bbox_sampler.fg_fraction)
        return self.call_mask(model_outputs,
                              fpn_feats,
                              class_targets=class_targets,
                              box_targets=box_targets,
                              rpn_box_rois=rpn_box_rois,
                              proposal_to_label_map=proposal_to_label_map,
                              gt_masks=gt_masks,
                              max_fg=max_fg,
                              training=True)
    
    def call_mask(self, 
                  model_outputs,
                  fpn_feats,
                  class_targets=None,
                  box_targets=None,
                  rpn_box_rois=None,
                  proposal_to_label_map=None,
                  gt_masks=None,
                  max_fg=None,
                  training=True):
        if not training:
            selected_box_rois = model_outputs['detection_boxes']
            class_indices = model_outputs['detection_classes']
            class_indices = tf.cast(class_indices, dtype=tf.int32)
            
        else:
            selected_class_targets, selected_box_targets, \
            selected_box_rois, proposal_to_label_map = training_ops.select_fg_for_masks(
                class_targets=class_targets,
                box_targets=box_targets,
                boxes=rpn_box_rois,
                proposal_to_label_map=proposal_to_label_map,
                max_num_fg=max_fg
            )
            
            class_indices = selected_class_targets
            class_indices = tf.cast(selected_class_targets, dtype=tf.int32)
            
        mask_roi_features = self.mask_roi_extractor(
                fpn_feats,
                selected_box_rois,
            )
        
        mask_outputs = self.mask_head(inputs=mask_roi_features, class_indices=class_indices)
        
        if training:
            mask_targets = training_ops.get_mask_targets(
                fg_boxes=selected_box_rois,
                fg_proposal_to_label_map=proposal_to_label_map,
                fg_box_targets=selected_box_targets,
                mask_gt_labels=gt_masks,
                output_size=self.mask_head._mrcnn_resolution
            )

            model_outputs.update({
                'mask_outputs': mask_outputs,
                'mask_targets': mask_targets,
                'selected_class_targets': selected_class_targets,
            })
            
            mask_loss = self.mask_head.loss(model_outputs['mask_outputs'],
                                             model_outputs['mask_targets'],
                                             model_outputs['selected_class_targets'],)
            model_outputs.update({'mask_loss': mask_loss})

        else:
            model_outputs.update({
                'detection_masks': tf.nn.sigmoid(mask_outputs),
            })

        return model_outputs

@HEADS.register("StandardRoIHead")
def build_standard_roi_head(cfg):
    roi_head = StandardRoIHead
    bbox_head = build_box_head(cfg)
    bbox_roi_extractor = GenericRoIExtractor(cfg.MODEL.FRCNN.ROI_SIZE,
                                            cfg.MODEL.FRCNN.GPU_INFERENCE)
    bbox_sampler = RandomSampler(batch_size_per_im=cfg.MODEL.RCNN.BATCH_SIZE_PER_IMAGE,
                                 fg_fraction=cfg.MODEL.RCNN.FG_FRACTION, 
                                 fg_thresh=cfg.MODEL.RCNN.THRESH,     
                                 bg_thresh_hi=cfg.MODEL.RCNN.THRESH_HI, 
                                 bg_thresh_lo=cfg.MODEL.RCNN.THRESH_LO)
    box_encoder = TargetEncoder(bbox_reg_weights=cfg.MODEL.BBOX_REG_WEIGHTS)
    inference_detector = BoxDetector(use_batched_nms=cfg.MODEL.INFERENCE.USE_BATCHED_NMS,
                                     rpn_post_nms_topn=cfg.MODEL.INFERENCE.POST_NMS_TOPN,
                                     detections_per_image=cfg.MODEL.INFERENCE.DETECTIONS_PER_IMAGE,
                                     test_nms=cfg.MODEL.INFERENCE.DETECTOR_NMS,
                                     class_agnostic_box=cfg.MODEL.INFERENCE.CLASS_AGNOSTIC,
                                     bbox_reg_weights=cfg.MODEL.BBOX_REG_WEIGHTS)

    if cfg.MODEL.INCLUDE_MASK:
        mask_roi_extractor = GenericRoIExtractor(cfg.MODEL.MRCNN.ROI_SIZE,
                                                 cfg.MODEL.MRCNN.GPU_INFERENCE)
        mask_head = build_mask_head(cfg)
    else:
        mask_head = None
        mask_roi_extractor = None
    return roi_head(bbox_head=bbox_head,
                    bbox_roi_extractor=bbox_roi_extractor,
                    bbox_sampler=bbox_sampler,
                    box_encoder=box_encoder,
                    inference_detector=inference_detector,
                    mask_head=mask_head,
                    mask_roi_extractor=mask_roi_extractor)