import numpy as np
from numba import jit
from collections import deque
import itertools
import os
import os.path as osp
import time
import torch
import cv2
import torch.nn.functional as F

from models.model import create_model, load_model
from models.decode import mot_decode
from tracking_utils.utils import *
from tracking_utils.log import logger
from tracking_utils.kalman_filter import KalmanFilter
from models import *
from tracker import matching
from .basetrack import BaseTrack, TrackState
from utils.post_process import ctdet_post_process
from utils.image import get_affine_transform
from models.utils import _tranpose_and_gather_feat

from datetime import datetime

import logging

from numba import jit
#logging.basicConfig(filename='demo.log',level=logging.INFO)

# Wall-clock timestamp taken once at import time. In this module it is only
# referenced by the commented-out logging.basicConfig call just below.
now_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
#logging.basicConfig(filename='./multitracker-{}.log'.format(now_str),
#                level=logging.DEBUG)

class STrack(BaseTrack):
    """One tracked target: a Kalman-filtered box plus an appearance embedding.

    A freshly constructed STrack is just a detection (no Kalman state); it
    becomes a live track once ``activate`` is called. Boxes appear in three
    layouts:

    * ``tlwh`` - (top-left x, top-left y, width, height)
    * ``tlbr`` - (min x, min y, max x, max y)
    * ``xyah`` - (center x, center y, width / height, height), the layout
      handed to the Kalman filter.
    """

    # Shared filter used by multi_predict() to predict many tracks at once.
    shared_kalman = KalmanFilter()

    def __init__(self, tlwh, score, temp_feat, buffer_size=30):
        """Create an unactivated track from a single detection.

        :param tlwh: detection box as (top-left x, top-left y, width, height).
        :param score: detection confidence.
        :param temp_feat: appearance embedding of the detection.
        :param buffer_size: maximum number of past embeddings kept in
            ``self.features``.
        """
        # wait activate: there is no Kalman state until activate() runs.
        # np.float was merely an alias of the builtin float (float64) and was
        # removed in NumPy 1.24, so the dtype is spelled out explicitly.
        self._tlwh = np.asarray(tlwh, dtype=np.float64)
        self.kalman_filter = None
        self.mean, self.covariance = None, None
        self.is_activated = False

        self.score = score
        self.tracklet_len = 0

        # update_features() appends to self.features and may read self.alpha,
        # so both have to exist before it is first called. When they were
        # assigned afterwards, the first append either raised AttributeError
        # or, if BaseTrack defines a `features` attribute, went there instead
        # of into this track's own bounded history.
        self.features = deque([], maxlen=buffer_size)
        self.alpha = 0.9
        self.smooth_feat = None
        self.update_features(temp_feat)

    def update_features(self, feat):
        """Record a new appearance embedding and refresh the smoothed one.

        ``curr_feat`` becomes the L2-normalised ``feat`` and is appended to
        ``features``. ``smooth_feat`` is an exponential moving average
        (weight ``alpha`` on the previous value) re-normalised to unit length.
        """
        # Normalise into a new array instead of in place so the caller's
        # buffer (a row of the detector's embedding matrix, or another track's
        # curr_feat) is never modified behind its back.
        feat = feat / np.linalg.norm(feat)
        self.curr_feat = feat
        if self.smooth_feat is None:
            smooth = feat
        else:
            smooth = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
        self.features.append(feat)
        self.smooth_feat = smooth / np.linalg.norm(smooth)

    def predict(self):
        """Advance this track's Kalman state by one step using its own filter."""
        mean_state = self.mean.copy()
        if self.state != TrackState.Tracked:
            # Zero state entry 7 for tracks that are not currently tracked;
            # presumably the height velocity of an (x, y, a, h, vx, vy, va, vh)
            # state - confirm against KalmanFilter.
            mean_state[7] = 0
        self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

    @staticmethod
    def multi_predict(stracks):
        """Predict all given tracks in one vectorized call of ``shared_kalman``.

        Each track's ``mean`` and ``covariance`` are overwritten in place.
        Does nothing for an empty list.
        """
        if len(stracks) > 0:
            multi_mean = np.asarray([st.mean.copy() for st in stracks])
            multi_covariance = np.asarray([st.covariance for st in stracks])
            for i, st in enumerate(stracks):
                if st.state != TrackState.Tracked:
                    # Same entry-7 reset as predict() for non-tracked tracks.
                    multi_mean[i][7] = 0
            multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
            for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
                stracks[i].mean = mean
                stracks[i].covariance = cov

    def activate(self, kalman_filter, frame_id):
        """Start a new tracklet: assign an id and initialise the Kalman state."""
        self.kalman_filter = kalman_filter
        self.track_id = self.next_id()
        self.mean, self.covariance = self.kalman_filter.initiate(self.tlwh_to_xyah(self._tlwh))

        self.tracklet_len = 0
        self.state = TrackState.Tracked
        # Only tracks born on the very first frame are confirmed immediately;
        # later ones stay unconfirmed until matched again in a following frame.
        if frame_id == 1:
            self.is_activated = True
        self.frame_id = frame_id
        self.start_frame = frame_id

    def re_activate(self, new_track, frame_id, new_id=False):
        """Revive a lost track with a matched detection.

        :param new_track: the matched detection (an STrack).
        :param frame_id: current frame number.
        :param new_id: when True, the track is given a fresh id.
        """
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
        )

        self.update_features(new_track.curr_feat)
        self.tracklet_len = 0
        self.state = TrackState.Tracked
        self.is_activated = True
        self.frame_id = frame_id
        if new_id:
            self.track_id = self.next_id()

    def update(self, new_track, frame_id, update_feature=True):
        """
        Update a matched track
        :type new_track: STrack
        :type frame_id: int
        :type update_feature: bool
        :return:
        """
        self.frame_id = frame_id
        self.tracklet_len += 1

        new_tlwh = new_track.tlwh
        self.mean, self.covariance = self.kalman_filter.update(
            self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
        self.state = TrackState.Tracked
        self.is_activated = True

        self.score = new_track.score
        if update_feature:
            self.update_features(new_track.curr_feat)

    @property
    def tlwh(self):
        """Get current position in bounding box format `(top left x, top left y,
                width, height)`.

        Before activation this is the detection box; afterwards it is derived
        from the first four entries (x, y, a, h) of the Kalman mean.
        """
        if self.mean is None:
            return self._tlwh.copy()
        ret = self.mean[:4].copy()
        ret[2] *= ret[3]          # width = aspect ratio * height
        ret[:2] -= ret[2:] / 2    # center -> top-left corner
        return ret

    @property
    def tlbr(self):
        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
        `(top left, bottom right)`.
        """
        ret = self.tlwh.copy()
        ret[2:] += ret[:2]
        return ret

    @staticmethod
    @jit(nopython=True)
    def tlwh_to_xyah(tlwh):
        """Convert bounding box to format `(center x, center y, aspect ratio,
        height)`, where the aspect ratio is `width / height`.
        """
        ret = np.asarray(tlwh).copy()
        ret[:2] += ret[2:] / 2
        ret[2] /= ret[3]
        return ret

    def to_xyah(self):
        """Return the current box in ``xyah`` layout."""
        return self.tlwh_to_xyah(self.tlwh)

    @staticmethod
    @jit(nopython=True)
    def tlbr_to_tlwh(tlbr):
        """Convert ``(x1, y1, x2, y2)`` to ``(x1, y1, w, h)``."""
        ret = np.asarray(tlbr).copy()
        ret[2:] -= ret[:2]
        return ret

    @staticmethod
    @jit(nopython=True)
    def tlwh_to_tlbr(tlwh):
        """Convert ``(x1, y1, w, h)`` to ``(x1, y1, x2, y2)``."""
        ret = np.asarray(tlwh).copy()
        ret[2:] += ret[:2]
        return ret

    def __repr__(self):
        return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)

class FrameTracker(object):
    """Tracking state owned by a single video stream.

    Holds the stream's own Kalman filter plus three lists of STrack objects:
    tracks currently being followed, tracks temporarily lost, and tracks that
    have been retired.
    """

    def __init__(self):
        self.kalman_filter = KalmanFilter()
        self.tracked_stracks, self.lost_stracks, self.removed_stracks = [], [], []

class ModelWrapper(torch.nn.Module):
    """Expose the detector's head outputs as a flat list.

    The wrapped model must return a sequence whose first element is a dict
    containing the keys 'hm', 'wh', 'id' and 'reg'. ``forward`` returns those
    four values, in that order, as a list.
    """

    def __init__(self, model):
        super(ModelWrapper, self).__init__()
        self.model = model

    def forward(self, x):
        # Evaluate the network a single time and pick the four heads out of
        # its output dict; calling self.model(x) once per head would repeat
        # the full forward pass four times for identical results.
        heads = self.model(x)[0]
        return [heads['hm'], heads['wh'], heads['id'], heads['reg']]

class JDETracker(object):
    """Joint detection-and-embedding multi-object tracker over several streams.

    One detector forward pass is run on a batch holding one frame per video
    stream. Each stream then keeps its own FrameTracker with its track lists.
    A single letterbox geometry and a single ``meta`` dict are used for every
    stream, so all streams must share the frame size ``opt.frame_w`` x
    ``opt.frame_h``. Frames are letterboxed to ``opt.inp_w`` x ``opt.inp_h``.
    """

    def __init__(self, opt, batch_size=1):
        """Build the detector and one FrameTracker per stream.

        :param opt: options object. Attributes read here: gpus, arch, heads,
            head_conv, pretrained, load_model, frame_w, frame_h, inp_w, inp_h,
            down_ratio (via cal_meta), conf_thres, frame_rate, track_buffer,
            K, mean, std. ``opt.device`` is assigned as a side effect.
        :param batch_size: number of video streams handled per update() call.
        """
        self.opt = opt
        if opt.gpus[0] >= 0:
            opt.device = torch.device('cuda')
        else:
            opt.device = torch.device('cpu')
        print('Creating model...')

        self.model = create_model(opt.arch, opt.heads, opt.head_conv, pretrained=opt.pretrained)
        self.model = load_model(self.model, opt.load_model)
        self.model = self.model.to(opt.device)
        self.model.eval()

        self.batch_size = batch_size
        self.frame_trackers = [FrameTracker() for _ in range(batch_size)]

        self.frame_w = opt.frame_w
        self.frame_h = opt.frame_h
        self.inp_w = opt.inp_w
        self.inp_h = opt.inp_h

        # The letterbox geometry is computed once because all streams share
        # one frame size.
        self.new_shape, self.top, self.bottom, self.left, self.right = self.cal_new_shape()

        self.cal_meta(self.frame_w, self.frame_h, self.inp_w, self.inp_h)

        # One frame counter for all streams; every update() advances it once.
        self.frame_id = 0
        self.det_thresh = opt.conf_thres
        # The lost-track buffer is specified for 30 fps; scale to the real rate.
        self.buffer_size = int(opt.frame_rate / 30.0 * opt.track_buffer)
        self.max_time_lost = self.buffer_size
        self.max_per_image = opt.K
        # NOTE(review): mean/std are stored but preprocessing() does not apply
        # them; confirm the network expects plain [0, 1] input.
        self.mean = np.array(opt.mean, dtype=np.float32).reshape(1, 1, 3)
        self.std = np.array(opt.std, dtype=np.float32).reshape(1, 1, 3)

        # Presumably resets the counter behind BaseTrack.next_id(), so track
        # ids restart for this tracker instance - confirm in basetrack.
        BaseTrack._count = 0

    def cal_new_shape(self):
        """Compute the letterbox resize target and border padding.

        :return: ``(new_shape, top, bottom, left, right)``. ``new_shape`` is the
            ``(width, height)`` handed to cv2.resize. The other four values are
            border sizes in pixels that pad the result up to
            ``(inp_w, inp_h)``.
        """
        ratio = min(float(self.inp_h) / self.frame_h, float(self.inp_w) / self.frame_w)
        new_shape = (round(self.frame_w * ratio), round(self.frame_h * ratio))

        dw = (self.inp_w - new_shape[0]) / 2
        dh = (self.inp_h - new_shape[1]) / 2
        # The +/-0.1 nudge puts an odd leftover pixel on the bottom/right side.
        top, bottom = round(dh - 0.1), round(dh + 0.1)
        left, right = round(dw - 0.1), round(dw + 0.1)

        return new_shape, top, bottom, left, right

    def cal_meta(self, frame_w, frame_h, inp_w, inp_h):
        """Store in ``self.meta`` the parameters used by post_process().

        'c' is the frame center and 's' the scale. The output size is the
        network input size divided by ``opt.down_ratio``.
        """
        c = np.array([frame_w / 2., frame_h / 2.], dtype=np.float32)
        s = max(float(inp_w) / float(inp_h) * frame_h, frame_w) * 1.0
        self.meta = {'c': c, 's': s,
                     'out_height': inp_h // self.opt.down_ratio,
                     'out_width': inp_w // self.opt.down_ratio}

    def post_process(self, dets, meta):
        """Map decoded detections of one image back to frame coordinates.

        :param dets: detection tensor for a single image, as returned by
            mot_decode.
        :param meta: dict built by cal_meta().
        :return: dict mapping class id (1..num_classes) to a float32 array of
            shape (N, 5).
        """
        dets = dets.detach().cpu().numpy()
        dets = dets.reshape(1, -1, dets.shape[2])
        dets = ctdet_post_process(
            dets.copy(), [meta['c']], [meta['s']],
            meta['out_height'], meta['out_width'], self.opt.num_classes)

        dets = dets[0]

        for j in range(1, self.opt.num_classes + 1):
            dets[j] = np.array(dets[j], dtype=np.float32).reshape(-1, 5)

        return dets

    def merge_outputs(self, detections):
        """Concatenate per-class detections and keep at most ``max_per_image``.

        Column 4 of each row is treated as the score. When there are too many
        rows across all classes, rows scoring below the ``max_per_image``-th
        best score are dropped. Ties at the threshold are all kept.
        """
        results = {}
        for j in range(1, self.opt.num_classes + 1):
            results[j] = np.concatenate(
                [detection[j] for detection in detections], axis=0).astype(np.float32)

        scores = np.hstack(
            [results[j][:, 4] for j in range(1, self.opt.num_classes + 1)])
        if len(scores) > self.max_per_image:
            kth = len(scores) - self.max_per_image
            thresh = np.partition(scores, kth)[kth]
            for j in range(1, self.opt.num_classes + 1):
                keep_inds = (results[j][:, 4] >= thresh)
                results[j] = results[j][keep_inds]
        return results

    def associate_tracker(self, dets, meta, id_feature, stream_id):
        """Run one tracking step for a single stream.

        :param dets: raw decoded detections for this stream's frame.
        :param meta: dict built by cal_meta().
        :param id_feature: numpy array of embeddings, one row per detection.
        :param stream_id: index into ``self.frame_trackers``.
        :return: the stream's activated (confirmed) tracked STracks.
        """
        activated_stracks = []
        refind_stracks = []
        lost_stracks = []
        removed_stracks = []
        tracked_stracks = []

        frame_tracker = self.frame_trackers[stream_id]

        ''' Detections: keep class 1 above the confidence threshold'''
        dets = self.post_process(dets, meta)
        dets = self.merge_outputs([dets])[1]
        remain_inds = dets[:, 4] > self.opt.conf_thres
        dets = dets[remain_inds]
        id_feature = id_feature[remain_inds]

        # Rows are (x1, y1, x2, y2, score); an empty dets yields an empty list.
        detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30)
                      for (tlbrs, f) in zip(dets[:, :5], id_feature)]

        ''' Split existing tracks into confirmed and unconfirmed ones'''
        unconfirmed = []
        for track in frame_tracker.tracked_stracks:
            if track.is_activated:
                tracked_stracks.append(track)
            else:
                unconfirmed.append(track)

        ''' Step 2: First association, with embedding'''
        strack_pool = joint_stracks(tracked_stracks, frame_tracker.lost_stracks)
        # Predict the current location with KF
        STrack.multi_predict(strack_pool)
        dists = matching.embedding_distance(strack_pool, detections)
        dists = matching.fuse_motion(frame_tracker.kalman_filter, dists, strack_pool, detections)
        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.8)

        for itracked, idet in matches:
            track = strack_pool[itracked]
            det = detections[idet]
            if track.state == TrackState.Tracked:
                track.update(det, self.frame_id)
                activated_stracks.append(track)
            else:
                # frame_id is shared by every video stream
                track.re_activate(det, self.frame_id, new_id=False)
                refind_stracks.append(track)

        ''' Step 3: Second association, with IOU'''
        detections = [detections[i] for i in u_detection]
        r_tracked_stracks = [strack_pool[i] for i in u_track
                             if strack_pool[i].state == TrackState.Tracked]
        dists = matching.iou_distance(r_tracked_stracks, detections)
        matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)

        for itracked, idet in matches:
            # r_tracked_stracks only holds Tracked tracks, so every match here
            # is a plain update (no re-activation path is reachable).
            track = r_tracked_stracks[itracked]
            track.update(detections[idet], self.frame_id)
            activated_stracks.append(track)

        for it in u_track:
            track = r_tracked_stracks[it]
            if track.state != TrackState.Lost:
                track.mark_lost()
                lost_stracks.append(track)

        '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
        detections = [detections[i] for i in u_detection]
        dists = matching.iou_distance(unconfirmed, detections)
        matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
        for itracked, idet in matches:
            unconfirmed[itracked].update(detections[idet], self.frame_id)
            activated_stracks.append(unconfirmed[itracked])

        for it in u_unconfirmed:
            track = unconfirmed[it]
            track.mark_removed()
            removed_stracks.append(track)

        """ Step 4: Init new stracks"""
        for inew in u_detection:
            track = detections[inew]
            if track.score < self.det_thresh:
                continue
            track.activate(frame_tracker.kalman_filter, self.frame_id)
            activated_stracks.append(track)

        """ Step 5: Update state"""
        for track in frame_tracker.lost_stracks:
            if self.frame_id - track.end_frame > self.max_time_lost:
                track.mark_removed()
                removed_stracks.append(track)

        frame_tracker.tracked_stracks = [t for t in frame_tracker.tracked_stracks
                                         if t.state == TrackState.Tracked]
        frame_tracker.tracked_stracks = joint_stracks(frame_tracker.tracked_stracks, activated_stracks)
        frame_tracker.tracked_stracks = joint_stracks(frame_tracker.tracked_stracks, refind_stracks)

        # Record this frame's removals BEFORE pruning the lost list. Otherwise
        # lost tracks that expired in Step 5 stay in lost_stracks for another
        # frame: they re-enter the matching pool and get appended to
        # removed_stracks a second time.
        # NOTE(review): removed_stracks is never trimmed and grows with every
        # retired track over the stream's lifetime.
        frame_tracker.removed_stracks.extend(removed_stracks)
        frame_tracker.lost_stracks = sub_stracks(frame_tracker.lost_stracks,
                                                 frame_tracker.tracked_stracks)
        frame_tracker.lost_stracks.extend(lost_stracks)
        frame_tracker.lost_stracks = sub_stracks(frame_tracker.lost_stracks,
                                                 frame_tracker.removed_stracks)
        frame_tracker.tracked_stracks, frame_tracker.lost_stracks = remove_duplicate_stracks(
            frame_tracker.tracked_stracks, frame_tracker.lost_stracks)

        logging.info('tracker_stracker: %d, removed_stracks: %d, lost_stracks: %d',
                     len(frame_tracker.tracked_stracks),
                     len(frame_tracker.removed_stracks),
                     len(frame_tracker.lost_stracks))

        output_stracks = [track for track in frame_tracker.tracked_stracks if track.is_activated]

        # Only build the per-frame id lists when debug output is actually on.
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            logging.debug('===========Frame {}=========='.format(self.frame_id))
            logging.debug('Activated: {}'.format([track.track_id for track in activated_stracks]))
            logging.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
            logging.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
            logging.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))

        return output_stracks

    def letterbox(self, img, color=(127.5, 127.5, 127.5)):
        """Resize ``img`` keeping aspect ratio, then pad it to the input size."""
        img = cv2.resize(img, self.new_shape, interpolation=cv2.INTER_AREA)
        img = cv2.copyMakeBorder(img, self.top, self.bottom, self.left,
                                 self.right, cv2.BORDER_CONSTANT, value=color)
        return img

    def preprocessing(self, imgs):
        """Turn HWC images into one float32 NCHW batch on ``opt.device``.

        Each image is letterboxed, its channel order is reversed
        (presumably BGR -> RGB for OpenCV-decoded frames), and its values are
        scaled by 1/255.
        """
        chw_images = []
        for img in imgs:
            img = self.letterbox(img)
            img = img[:, :, ::-1].transpose(2, 0, 1)
            chw_images.append(np.ascontiguousarray(img, dtype=np.float32))
        batch = np.stack(chw_images, axis=0)
        batch /= 255.0
        # Use the configured device instead of a hard-coded .cuda(), which
        # crashed whenever the tracker was configured for CPU. The batch is
        # also transferred once rather than image by image.
        return torch.from_numpy(batch).to(self.opt.device)

    def update(self, imgs):
        """Track one frame per stream.

        :param imgs: sequence of frames, ``imgs[i]`` belonging to stream ``i``.
            At most ``batch_size`` frames are used. Fewer frames are accepted,
            and only those streams are updated.
        :return: one dict per processed stream mapping track_id to its tlwh box
            (as a list), after the size/shape filter below.
        """
        im_blobs = self.preprocessing(imgs)
        self.frame_id += 1

        # A short final batch used to raise IndexError; only process the
        # streams we both have trackers for and received frames for.
        num_streams = min(self.batch_size, im_blobs.shape[0])

        ''' Step 1: Network forward, get detections & embeddings'''
        with torch.no_grad():
            output = self.model(im_blobs)[-1]
            dets_list = []
            id_feature_list = []

            for i in range(num_streams):
                # Slice i:i+1 keeps a leading batch axis of size 1.
                hm = output['hm'][i:i + 1].sigmoid_()
                wh = output['wh'][i:i + 1]
                reg = output['reg'][i:i + 1] if self.opt.reg_offset else None
                id_feature = F.normalize(output['id'][i:i + 1], dim=1)

                dets, inds = mot_decode(hm, wh, reg=reg, ltrb=self.opt.ltrb, K=self.opt.K)

                id_feature = _tranpose_and_gather_feat(id_feature, inds)
                id_feature = id_feature.squeeze(0).cpu().numpy()
                dets_list.append(dets)
                id_feature_list.append(id_feature)

        output_stracks_list = []
        for i in range(num_streams):
            output_stracks = self.associate_tracker(dets_list[i], self.meta,
                                                    id_feature_list[i], i)

            tracker_dict = {}
            for t in output_stracks:
                tlwh = t.tlwh
                # Drop boxes wider than 1.6x their height (not a standing
                # person) and boxes whose area is at most min_box_area.
                too_wide = tlwh[2] / tlwh[3] > 1.6
                if tlwh[2] * tlwh[3] > self.opt.min_box_area and not too_wide:
                    tracker_dict[t.track_id] = list(tlwh)

            output_stracks_list.append(tracker_dict)
        return output_stracks_list

def joint_stracks(tlista, tlistb):
    """Union of two track lists keyed by ``track_id``.

    Every entry of ``tlista`` is kept, in order. Entries of ``tlistb`` follow,
    skipping any whose id was already seen in ``tlista`` or earlier in
    ``tlistb``.
    """
    seen_ids = {track.track_id for track in tlista}
    merged = list(tlista)
    for track in tlistb:
        if track.track_id in seen_ids:
            continue
        seen_ids.add(track.track_id)
        merged.append(track)
    return merged

def sub_stracks(tlista, tlistb):
    """Return the tracks of ``tlista`` whose ``track_id`` is absent from ``tlistb``.

    If ``tlista`` holds several tracks with the same id, the last one is kept,
    at the position of the first occurrence.
    """
    stracks = {t.track_id: t for t in tlista}
    for t in tlistb:
        # pop() removes by key regardless of the track object's truthiness,
        # unlike the former `if stracks.get(tid, 0): del ...` test.
        stracks.pop(t.track_id, None)
    return list(stracks.values())


def remove_duplicate_stracks(stracksa, stracksb):
    """Drop overlapping duplicates between two track lists.

    A pair (one track from each list) is a duplicate when their
    ``matching.iou_distance`` is below 0.15. For each such pair, the track
    with the shorter history (``frame_id - start_frame``) is dropped. On a tie,
    the one from ``stracksa`` is dropped.

    :return: ``(resa, resb)``, the filtered lists in their original order.
    """
    pdist = matching.iou_distance(stracksa, stracksb)
    pairs = np.where(pdist < 0.15)
    # Sets give O(1) membership checks in the filters below.
    dupa, dupb = set(), set()
    for p, q in zip(*pairs):
        timep = stracksa[p].frame_id - stracksa[p].start_frame
        timeq = stracksb[q].frame_id - stracksb[q].start_frame
        if timep > timeq:
            dupb.add(int(q))
        else:
            dupa.add(int(p))
    resa = [t for i, t in enumerate(stracksa) if i not in dupa]
    resb = [t for i, t in enumerate(stracksb) if i not in dupb]
    return resa, resb