import mxnet as mx
from mxnet import gluon, nd
from mxnet.gluon.model_zoo import vision
import numpy as np


class Image2Vec:
    '''
    Encapsulates all the logic to transform a Pillow image object into a
    vector representation using the feature-extraction part of a
    pretrained ResNet-18 (v2) from the Gluon model zoo.
    '''

    # Spatial size (width, height) of the crop fed to the network.
    INPUT_SIZE = (224, 224)

    def __init__(self):
        '''
        Load the pretrained network on the CPU and set up the per-channel
        normalization statistics.
        '''
        self._ctx = mx.cpu()
        # Only the ``features`` sub-block is kept: the classifier head is
        # dropped so the network outputs an embedding, not class scores.
        self._net = vision.resnet18_v2(pretrained=True, ctx=self._ctx).features
        # Per-channel (RGB) mean / std applied after scaling pixels to [0, 1].
        self.MEAN_IMAGE = mx.nd.array([0.485, 0.456, 0.406])
        self.STD_IMAGE = mx.nd.array([0.229, 0.224, 0.225])

    def preprocess_image(self, image):
        '''
        Preprocess an input Pillow image object.

        Steps: convert to an NDArray with three channels, resize so the
        shorter edge matches the network input, center-crop to
        ``INPUT_SIZE``, scale to [0, 1], normalize with ``MEAN_IMAGE`` /
        ``STD_IMAGE`` and reorder from HWC to CHW.

        :param image: Pillow image (or array-like) with shape (H, W) or
            (H, W, 3).
        :return: float32 NDArray laid out as (channels, height, width).
        :raises AssertionError: if the image is not 2-D or 3-channel 3-D.
        '''
        image_nd = self.correct_channel(nd.array(image))
        width, height = self.INPUT_SIZE
        # Resize the shorter edge to the crop size so the center crop never
        # has to upsample (and therefore never distorts) the image.
        resized = mx.image.resize_short(image_nd,
                                        min(width, height)).astype('float32')
        # center_crop also returns the crop rectangle, which is not needed.
        cropped, _ = mx.image.center_crop(resized, (width, height))
        cropped /= 255.
        normalized = mx.image.color_normalize(cropped,
                                              mean=self.MEAN_IMAGE,
                                              std=self.STD_IMAGE)
        # HWC -> CHW, the layout consumed by the network in to_vector.
        return nd.transpose(normalized, (2, 0, 1))

    def correct_channel(self, image):
        '''
        Ensure the image array has exactly three channels.

        A two-dimensional (single channel, black-and-white) image is
        expanded to three channels by stacking it three times along a new
        last axis. Three-dimensional input is returned unchanged.

        :param image: array with a ``shape`` attribute, (H, W) or (H, W, 3).
        :return: array of shape (H, W, 3).
        :raises AssertionError: if the result is not 3-D with 3 channels
            (e.g. an RGBA image).
        '''
        if len(image.shape) == 2:
            # Correct one channel (black-white) image to three channel
            # (RGB) image by stacking.
            image = nd.stack(image, image, image, axis=2)

        assert len(image.shape) == 3, (
            'expected an (H, W) or (H, W, 3) image, got shape %s'
            % (tuple(image.shape),))
        assert image.shape[2] == 3, (
            'expected 3 channels, got %d' % image.shape[2])
        return image

    def to_vector(self, image):
        '''
        Vectorize an input Pillow image object.

        :param image: Pillow image (or array-like) accepted by
            ``preprocess_image``.
        :return: one-dimensional numpy array holding the flattened output
            of the feature network for this image.
        :raises AssertionError: if the image has an unsupported shape.
        '''
        image_t = self.preprocess_image(image)
        # Add the leading batch axis the network expects (batch of one).
        batch = image_t.expand_dims(axis=0).as_in_context(self._ctx)
        output = self._net(batch)
        return output.asnumpy().reshape(-1, )