#!/usr/bin/env python3 # -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from __future__ import print_function import argparse import os import subprocess import sys curr_path = os.path.abspath(os.path.dirname(__file__)) sys.path.append(os.path.join(curr_path, "..")) from concat_db import ConcatDB from pascal_voc import PascalVoc def load_pascal(image_set, year, devkit_path, shuffle=False): """ wrapper function for loading pascal voc dataset Parameters: ---------- image_set : str train, trainval... year : str 2007, 2012 or combinations splitted by comma devkit_path : str root directory of dataset shuffle : bool whether to shuffle initial list Returns: ---------- Imdb """ image_set = [y.strip() for y in image_set.split(",")] assert image_set, "No image_set specified" year = [y.strip() for y in year.split(",")] assert year, "No year specified" # make sure (# sets == # years) if len(image_set) > 1 and len(year) == 1: year = year * len(image_set) if len(image_set) == 1 and len(year) > 1: image_set = image_set * len(year) assert len(image_set) == len(year), "Number of sets and year mismatch" imdbs = [] for s, y in zip(image_set, year): imdbs.append(PascalVoc(s, y, devkit_path, shuffle, is_train=True)) if len(imdbs) > 1: return ConcatDB(imdbs, shuffle) else: return imdbs[0] def parse_args(): parser = argparse.ArgumentParser(description="Prepare lists for dataset") parser.add_argument( "--dataset", dest="dataset", help="dataset to use", default="pascal", type=str ) parser.add_argument( "--year", dest="year", help="which year to use", default="2007,2012", type=str ) parser.add_argument( "--set", dest="set", help="train, val, trainval, test", default="trainval", type=str ) parser.add_argument( "--target", dest="target", help="output list file", default=os.path.join(curr_path, "..", "train.lst"), type=str, ) parser.add_argument( "--root", dest="root_path", help="dataset root path", default=os.path.join(curr_path, "..", "VOCdevkit"), type=str, ) parser.add_argument("--no-shuffle", dest="shuffle", help="shuffle list", action="store_false") parser.add_argument( "--num-thread", dest="num_thread", type=int, default=4, help="number of thread to use while runing im2rec.py", ) args = parser.parse_args() return args if __name__ == "__main__": args = parse_args() if args.dataset == "pascal": db = load_pascal(args.set, args.year, args.root_path, args.shuffle) print("saving list to disk...") db.save_imglist(args.target, root=args.root_path) else: raise NotImplementedError("No implementation for dataset: " + args.dataset) print("List file {} generated...".format(args.target)) cmd_arguments = [ "python", os.path.join(curr_path, "im2rec.py"), os.path.abspath(args.target), os.path.abspath(args.root_path), "--pack-label", "--num-thread", str(args.num_thread), ] if not args.shuffle: cmd_arguments.append("--no-shuffle") subprocess.check_call(cmd_arguments) print("Record file {} generated...".format(args.target.split(".")[0] + ".rec"))