#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import print_function

import os
import sys

curr_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(curr_path, "../python"))
import argparse
import random
import time
import traceback

import cv2
import mxnet as mx

try:
    import multiprocessing
except ImportError:
    multiprocessing = None


def list_image(root, recursive, exts):
    """Yield ``(index, relative_path, label)`` for every image found under ``root``.

    Parameters
    ----------
    root : str
        Directory to scan.
    recursive : bool
        When true, walk the whole tree (following symlinks) and give each
        directory that contains at least one image its own integer label,
        numbered in the order the directories are first seen. The
        directory-to-label mapping is printed once the walk is finished.
        When false, only regular files directly inside ``root`` are listed
        and every one of them gets label 0.
    exts : container of str
        Accepted file extensions, compared against the lower-cased suffix
        of each file name (dot included).

    Yields
    ------
    tuple
        ``(running index, path relative to root, label)``.
    """
    index = 0

    if not recursive:
        for name in sorted(os.listdir(root)):
            full = os.path.join(root, name)
            ext = os.path.splitext(name)[1].lower()
            if not (os.path.isfile(full) and ext in exts):
                continue
            yield (index, os.path.relpath(full, root), 0)
            index += 1
        return

    labels = {}
    for folder, subdirs, names in os.walk(root, followlinks=True):
        # Sorting in place makes the traversal order (and so the labels)
        # deterministic.
        subdirs.sort()
        names.sort()
        for name in names:
            full = os.path.join(folder, name)
            ext = os.path.splitext(name)[1].lower()
            if not (os.path.isfile(full) and ext in exts):
                continue
            # A folder receives the next free label the first time one of
            # its images is encountered.
            label = labels.setdefault(folder, len(labels))
            yield (index, os.path.relpath(full, root), label)
            index += 1

    # Report the folder -> label mapping in label order.
    for folder in sorted(labels, key=labels.get):
        print(os.path.relpath(folder, root), labels[folder])


def write_list(path_out, image_list):
    """Write image entries to ``path_out`` as a tab-separated ``.lst`` file.

    Each entry is a sequence ``(index, path, label, ...)``. It is written as
    one line: the integer index, then every label formatted with ``%f``,
    then the path, all separated by tabs.

    Parameters
    ----------
    path_out : str
        Destination file; it is opened in text write mode (truncating).
    image_list : iterable
        Entries to write, in order.
    """
    with open(path_out, "w") as fout:
        for item in image_list:
            fields = ["%d" % item[0]]
            fields.extend("%f" % label for label in item[2:])
            fields.append("%s\n" % item[1])
            fout.write("\t".join(fields))


def make_list(args):
    """Scan ``args.root`` for images and write the resulting ``.lst`` file(s).

    The image list is optionally shuffled (with a fixed seed of 100, so the
    order is reproducible), cut into ``args.chunks`` consecutive chunks and,
    for each chunk, written either as a single ``<prefix>.lst`` (when
    ``args.train_ratio`` is exactly 1.0) or as ``_test`` / ``_train`` /
    ``_val`` lists. With more than one chunk, ``_<chunk index>`` is inserted
    after the prefix.

    Split sizes are derived from the nominal chunk size, not from the actual
    length of the chunk being written.

    Parameters
    ----------
    args : argparse.Namespace
        Reads ``root``, ``recursive``, ``exts``, ``shuffle``, ``chunks``,
        ``train_ratio``, ``test_ratio`` and ``prefix``.
    """
    entries = list(list_image(args.root, args.recursive, args.exts))
    if args.shuffle is True:
        random.seed(100)
        random.shuffle(entries)

    total = len(entries)
    # Ceiling division so that every entry lands in some chunk.
    chunk_size = (total + args.chunks - 1) // args.chunks

    for idx in range(args.chunks):
        chunk = entries[idx * chunk_size : (idx + 1) * chunk_size]
        base = args.prefix + ("_%d" % idx if args.chunks > 1 else "")
        n_train = int(chunk_size * args.train_ratio)
        n_test = int(chunk_size * args.test_ratio)

        if args.train_ratio == 1.0:
            write_list(base + ".lst", chunk)
            continue

        # Layout inside a chunk: [ test | train | val ].
        if args.test_ratio:
            write_list(base + "_test.lst", chunk[:n_test])
        if args.train_ratio + args.test_ratio < 1.0:
            write_list(base + "_val.lst", chunk[n_test + n_train :])
        write_list(base + "_train.lst", chunk[n_test : n_test + n_train])


def read_list(path_in):
    """Parse a tab-separated ``.lst`` file and yield one entry per valid line.

    A valid line has at least three tab-separated fields: an integer index
    first, the image path last, and one or more float labels in between.
    Lines with fewer than three fields, or whose index/labels cannot be
    converted, are reported on stdout and skipped.

    Parameters
    ----------
    path_in : str
        Path of the list file to read.

    Yields
    ------
    list
        ``[index (int), path (str), label (float), ...]``.
    """
    with open(path_in) as fin:
        for raw in fin:
            fields = [part.strip() for part in raw.strip().split("\t")]
            n_fields = len(fields)
            if n_fields < 3:
                print(
                    "lst should at least has three parts, but only has %s parts for %s"
                    % (n_fields, fields)
                )
                continue
            try:
                index = int(fields[0])
                labels = [float(value) for value in fields[1:-1]]
            except Exception as e:
                print("Parsing lst met error for %s, detail: %s" % (fields, e))
                continue
            yield [index, fields[-1]] + labels


def image_encode(args, i, item, q_out):
    """Load one listed image, optionally transform it, pack it and queue the result.

    The outcome is reported through ``q_out`` as ``(i, packed, item)``, where
    ``packed`` is the packed record on success and ``None`` when the image
    could not be read or packed (the failure is printed to stdout).

    Parameters
    ----------
    args : argparse.Namespace
        Reads ``root``, ``pack_label``, ``pass_through``, ``color``,
        ``center_crop``, ``resize``, ``quality`` and ``encoding``.
    i : int
        Position of the entry in the input list; echoed back unchanged.
    item : list
        ``[index, relative path, label, ...]`` as produced by ``read_list``.
    q_out : queue-like
        Object with a ``put`` method that receives the result tuple.
    """
    fullpath = os.path.join(args.root, item[1])

    # Header label: the whole label vector when multi-label packing is
    # requested and the entry carries more than one label, else the first one.
    if len(item) > 3 and args.pack_label:
        header = mx.recordio.IRHeader(0, item[2:], item[0], 0)
    else:
        header = mx.recordio.IRHeader(0, item[2], item[0], 0)

    if args.pass_through:
        # Store the file bytes untouched: no decode, no re-encode.
        try:
            with open(fullpath, "rb") as fin:
                img = fin.read()
            s = mx.recordio.pack(header, img)
            q_out.put((i, s, item))
        except Exception as e:
            traceback.print_exc()
            print("pack_img error:", item[1], e)
            q_out.put((i, None, item))
        return

    try:
        img = cv2.imread(fullpath, args.color)
    except Exception:
        # Deliberately not a bare ``except``: KeyboardInterrupt / SystemExit
        # must still be able to stop the process.
        traceback.print_exc()
        print("imread error trying to load file: %s " % fullpath)
        q_out.put((i, None, item))
        return
    if img is None:
        print("imread read blank (None) image for file: %s" % fullpath)
        q_out.put((i, None, item))
        return

    if args.center_crop:
        # Trim the longer axis symmetrically so both sides have the length
        # of the shorter one (shape[0] is rows, shape[1] is columns).
        if img.shape[0] > img.shape[1]:
            margin = (img.shape[0] - img.shape[1]) // 2
            img = img[margin : margin + img.shape[1], :]
        else:
            margin = (img.shape[1] - img.shape[0]) // 2
            img = img[:, margin : margin + img.shape[0]]

    if args.resize:
        # Scale so that the shorter edge equals args.resize while keeping the
        # aspect ratio; the size tuple passed to cv2.resize is (width, height).
        if img.shape[0] > img.shape[1]:
            newsize = (args.resize, img.shape[0] * args.resize // img.shape[1])
        else:
            newsize = (img.shape[1] * args.resize // img.shape[0], args.resize)
        img = cv2.resize(img, newsize)

    try:
        s = mx.recordio.pack_img(header, img, quality=args.quality, img_fmt=args.encoding)
        q_out.put((i, s, item))
    except Exception as e:
        traceback.print_exc()
        print("pack_img error on file: %s" % fullpath, e)
        q_out.put((i, None, item))
        return


def read_worker(args, q_in, q_out):
    """Encode images taken from ``q_in`` until a ``None`` sentinel arrives.

    Parameters
    ----------
    args : argparse.Namespace
        Options forwarded unchanged to ``image_encode``.
    q_in : queue-like
        Source of ``(position, item)`` tuples; a ``None`` entry ends the loop.
    q_out : queue-like
        Destination handed to ``image_encode`` for the encoded results.
    """
    task = q_in.get()
    while task is not None:
        position, entry = task
        image_encode(args, position, entry, q_out)
        task = q_in.get()


def write_worker(q_out, fname, working_dir):
    """Drain ``q_out`` and write the packed records to ``.rec`` / ``.idx`` files.

    Results may arrive out of order, so they are held in a buffer and written
    strictly in increasing position order (0, 1, 2, ...). Entries whose packed
    payload is ``None`` (failed images) are skipped but still advance the
    position counter. A ``None`` entry on the queue signals the end of input.

    Parameters
    ----------
    q_out : queue-like
        Source of ``(position, packed_or_None, item)`` tuples, terminated by
        a single ``None``.
    fname : str
        Path of the ``.lst`` file; only its base name (without extension) is
        used to name the output files.
    working_dir : str
        Directory in which the ``.rec`` and ``.idx`` files are created.
    """
    pre_time = time.time()
    count = 0
    stem = os.path.splitext(os.path.basename(fname))[0]
    fname_rec = stem + ".rec"
    fname_idx = stem + ".idx"
    record = mx.recordio.MXIndexedRecordIO(
        os.path.join(working_dir, fname_idx), os.path.join(working_dir, fname_rec), "w"
    )
    buf = {}
    more = True
    try:
        while more:
            deq = q_out.get()
            if deq is not None:
                i, s, item = deq
                buf[i] = (s, item)
            else:
                more = False
            # Flush every buffered entry that is next in sequence.
            while count in buf:
                s, item = buf[count]
                del buf[count]
                if s is not None:
                    record.write_idx(item[0], s)

                if count % 1000 == 0:
                    cur_time = time.time()
                    print("time:", cur_time - pre_time, " count:", count)
                    pre_time = cur_time
                count += 1
    finally:
        # Close explicitly instead of relying on garbage collection, so the
        # record and index files are released even if writing fails.
        record.close()


def parse_args(argv=None):
    """Build the command-line parser and return the parsed options.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse, without the program name. When ``None``
        (the default) argparse reads them from ``sys.argv``, which is the
        behaviour callers relied on before this parameter existed.

    Returns
    -------
    argparse.Namespace
        Parsed options. ``prefix`` and ``root`` are converted to absolute
        paths before being returned.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Create an image list or \
        make a record database by reading from an image list",
    )
    parser.add_argument("prefix", help="prefix of input/output lst and rec files.")
    parser.add_argument("root", help="path to folder containing images.")

    # Options that only matter when generating .lst files (--list).
    cgroup = parser.add_argument_group("Options for creating image lists")
    cgroup.add_argument(
        "--list",
        action="store_true",
        help="If this is set im2rec will create image list(s) by traversing root folder\
        and output to <prefix>.lst.\
        Otherwise im2rec will read <prefix>.lst and create a database at <prefix>.rec",
    )
    cgroup.add_argument(
        "--exts",
        nargs="+",
        default=[".jpeg", ".jpg", ".png"],
        help="list of acceptable image extensions.",
    )
    cgroup.add_argument("--chunks", type=int, default=1, help="number of chunks.")
    cgroup.add_argument(
        "--train-ratio", type=float, default=1.0, help="Ratio of images to use for training."
    )
    cgroup.add_argument(
        "--test-ratio", type=float, default=0, help="Ratio of images to use for testing."
    )
    cgroup.add_argument(
        "--recursive",
        action="store_true",
        help="If true recursively walk through subdirs and assign an unique label\
        to images in each folder. Otherwise only include images in the root folder\
        and give them label 0.",
    )
    # Stored as args.shuffle, which is True unless --no-shuffle is given.
    cgroup.add_argument(
        "--no-shuffle",
        dest="shuffle",
        action="store_false",
        help="If this is passed, \
        im2rec will not randomize the image order in <prefix>.lst",
    )

    # Options that only matter when packing a .lst file into a .rec database.
    rgroup = parser.add_argument_group("Options for creating database")
    rgroup.add_argument(
        "--pass-through",
        action="store_true",
        help="whether to skip transformation and save image as is",
    )
    rgroup.add_argument(
        "--resize",
        type=int,
        default=0,
        help="resize the shorter edge of image to the newsize, original images will\
        be packed by default.",
    )
    rgroup.add_argument(
        "--center-crop",
        action="store_true",
        help="specify whether to crop the center image to make it square.",
    )
    rgroup.add_argument(
        "--quality",
        type=int,
        default=95,
        help="JPEG quality for encoding, 1-100; or PNG compression for encoding, 1-9",
    )
    rgroup.add_argument(
        "--num-thread",
        type=int,
        default=1,
        help="number of thread to use for encoding. order of images will be different\
        from the input list if >1. the input list will be modified to match the\
        resulting order.",
    )
    rgroup.add_argument(
        "--color",
        type=int,
        default=1,
        choices=[-1, 0, 1],
        help="specify the color mode of the loaded image.\
        1: Loads a color image. Any transparency of image will be neglected. It is the default flag.\
        0: Loads image in grayscale mode.\
        -1:Loads image as such including alpha channel.",
    )
    rgroup.add_argument(
        "--encoding",
        type=str,
        default=".jpg",
        choices=[".jpg", ".png"],
        help="specify the encoding of the images.",
    )
    rgroup.add_argument(
        "--pack-label",
        action="store_true",
        help="Whether to also pack multi dimensional label in the record file",
    )
    args = parser.parse_args(argv)
    args.prefix = os.path.abspath(args.prefix)
    args.root = os.path.abspath(args.root)
    return args


if __name__ == "__main__":
    # Script entry point. With --list, image list file(s) are generated from
    # the root folder. Otherwise every "<prefix>*.lst" file found in the
    # working directory is packed into a .rec/.idx pair next to it.
    args = parse_args()
    if args.list:
        make_list(args)
    else:
        # The prefix may name a directory (pack every list inside it) or a
        # path prefix (pack the matching lists in its parent directory).
        if os.path.isdir(args.prefix):
            working_dir = args.prefix
        else:
            working_dir = os.path.dirname(args.prefix)
        files = [
            os.path.join(working_dir, fname)
            for fname in os.listdir(working_dir)
            if os.path.isfile(os.path.join(working_dir, fname))
        ]
        count = 0
        for fname in files:
            if fname.startswith(args.prefix) and fname.endswith(".lst"):
                print("Creating .rec file from", fname, "in", working_dir)
                count += 1
                image_list = read_list(fname)
                # -- write_record -- #
                if args.num_thread > 1 and multiprocessing is not None:
                    # One bounded input queue per reader process and a single
                    # shared output queue feeding the writer process.
                    q_in = [multiprocessing.Queue(1024) for i in range(args.num_thread)]
                    q_out = multiprocessing.Queue(1024)
                    read_process = [
                        multiprocessing.Process(target=read_worker, args=(args, q_in[i], q_out))
                        for i in range(args.num_thread)
                    ]
                    for p in read_process:
                        p.start()
                    write_process = multiprocessing.Process(
                        target=write_worker, args=(q_out, fname, working_dir)
                    )
                    write_process.start()

                    # Deal the entries round-robin, then send each reader its
                    # None sentinel so it terminates.
                    for i, item in enumerate(image_list):
                        q_in[i % len(q_in)].put((i, item))
                    for q in q_in:
                        q.put(None)
                    for p in read_process:
                        p.join()

                    # All readers are done, so nothing more will be queued:
                    # tell the writer to finish.
                    q_out.put(None)
                    write_process.join()
                else:
                    print("multiprocessing not available, fall back to single threaded encoding")
                    try:
                        import Queue as queue
                    except ImportError:
                        import queue
                    q_out = queue.Queue()
                    fname = os.path.basename(fname)
                    fname_rec = os.path.splitext(fname)[0] + ".rec"
                    fname_idx = os.path.splitext(fname)[0] + ".idx"
                    record = mx.recordio.MXIndexedRecordIO(
                        os.path.join(working_dir, fname_idx),
                        os.path.join(working_dir, fname_rec),
                        "w",
                    )
                    cnt = 0
                    pre_time = time.time()
                    try:
                        for i, item in enumerate(image_list):
                            image_encode(args, i, item, q_out)
                            if q_out.empty():
                                continue
                            _, s, _ = q_out.get()
                            if s is None:
                                # image_encode already reported the failure;
                                # skip the entry like write_worker does
                                # instead of handing None to write_idx.
                                continue
                            record.write_idx(item[0], s)
                            if cnt % 1000 == 0:
                                cur_time = time.time()
                                print("time:", cur_time - pre_time, " count:", cnt)
                                pre_time = cur_time
                            cnt += 1
                    finally:
                        # Close explicitly rather than relying on garbage
                        # collection to release the record and index files.
                        record.close()
        if not count:
            print("Did not find any list file with prefix %s" % args.prefix)