# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from __future__ import print_function import os import xml.etree.ElementTree as ET import cv2 import numpy as np from imdb import Imdb class PascalVoc(Imdb): """ Implementation of Imdb for Pascal VOC datasets Parameters: ---------- image_set : str set to be used, can be train, val, trainval, test year : str year of dataset, can be 2007, 2010, 2012... devkit_path : str devkit path of VOC dataset shuffle : boolean whether to initial shuffle the image list is_train : boolean if true, will load annotations """ def __init__( self, image_set, year, devkit_path, shuffle=False, is_train=False, names="pascal_voc.names" ): super(PascalVoc, self).__init__("voc_" + year + "_" + image_set) self.image_set = image_set self.year = year self.devkit_path = devkit_path self.data_path = os.path.join(devkit_path, "VOC" + year) self.extension = ".jpg" self.is_train = is_train self.classes = self._load_class_names(names, os.path.dirname(__file__)) self.config = { "use_difficult": True, "comp_id": "comp4", } self.num_classes = len(self.classes) self.image_set_index = self._load_image_set_index(shuffle) self.num_images = len(self.image_set_index) if self.is_train: self.labels = self._load_image_labels() @property def cache_path(self): """ make a directory to store all caches Returns: --------- cache path """ cache_path = os.path.join(os.path.dirname(__file__), "..", "cache") if not os.path.exists(cache_path): os.mkdir(cache_path) return cache_path def _load_image_set_index(self, shuffle): """ find out which indexes correspond to given image set (train or val) Parameters: ---------- shuffle : boolean whether to shuffle the image list Returns: ---------- entire list of images specified in the setting """ image_set_index_file = os.path.join( self.data_path, "ImageSets", "Main", self.image_set + ".txt" ) assert os.path.exists(image_set_index_file), "Path does not exist: {}".format( image_set_index_file ) with open(image_set_index_file) as f: image_set_index = [x.strip() for x in f.readlines()] if shuffle: np.random.shuffle(image_set_index) return image_set_index def image_path_from_index(self, index): """ given image index, find out full path Parameters: ---------- index: int index of a specific image Returns: ---------- full path of this image """ assert self.image_set_index is not None, "Dataset not initialized" name = self.image_set_index[index] image_file = os.path.join(self.data_path, "JPEGImages", name + self.extension) assert os.path.exists(image_file), "Path does not exist: {}".format(image_file) return image_file def label_from_index(self, index): """ given image index, return preprocessed ground-truth Parameters: ---------- index: int index of a specific image Returns: ---------- ground-truths of this image """ assert self.labels is not None, "Labels not processed" return self.labels[index] def _label_path_from_index(self, index): """ given image index, find out annotation path Parameters: ---------- index: int index of a specific image Returns: ---------- full path of annotation file """ label_file = os.path.join(self.data_path, "Annotations", index + ".xml") assert os.path.exists(label_file), "Path does not exist: {}".format(label_file) return label_file def _load_image_labels(self): """ preprocess all ground-truths Returns: ---------- labels packed in [num_images x max_num_objects x 5] tensor """ temp = [] # load ground-truth from xml annotations for idx in self.image_set_index: label_file = self._label_path_from_index(idx) tree = ET.parse(label_file) root = tree.getroot() size = root.find("size") width = float(size.find("width").text) height = float(size.find("height").text) label = [] for obj in root.iter("object"): difficult = int(obj.find("difficult").text) # if not self.config['use_difficult'] and difficult == 1: # continue cls_name = obj.find("name").text if cls_name not in self.classes: continue cls_id = self.classes.index(cls_name) xml_box = obj.find("bndbox") xmin = float(xml_box.find("xmin").text) / width ymin = float(xml_box.find("ymin").text) / height xmax = float(xml_box.find("xmax").text) / width ymax = float(xml_box.find("ymax").text) / height label.append([cls_id, xmin, ymin, xmax, ymax]) temp.append(np.array(label)) return temp def get_result_file_template(self): """ this is a template VOCdevkit/results/VOC2007/Main/_det_test_aeroplane.txt Returns: ---------- a string template """ res_file_folder = os.path.join(self.devkit_path, "results", "VOC" + self.year, "Main") comp_id = self.config["comp_id"] filename = comp_id + "_det_" + self.image_set + "_{:s}.txt" path = os.path.join(res_file_folder, filename) return path def write_pascal_results(self, all_boxes): """ write results files in pascal devkit path Parameters: ---------- all_boxes: list boxes to be processed [bbox, confidence] Returns: ---------- None """ for cls_ind, cls in enumerate(self.classes): print("Writing {} VOC results file".format(cls)) filename = self.get_result_file_template().format(cls) with open(filename, "wt") as f: for im_ind, index in enumerate(self.image_set_index): dets = all_boxes[im_ind] if dets.shape[0] < 1: continue h, w = self._get_imsize(self.image_path_from_index(im_ind)) # the VOCdevkit expects 1-based indices for k in range(dets.shape[0]): if int(dets[k, 0]) == cls_ind: f.write( "{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n".format( index, dets[k, 1], int(dets[k, 2] * w) + 1, int(dets[k, 3] * h) + 1, int(dets[k, 4] * w) + 1, int(dets[k, 5] * h) + 1, ) ) def _get_imsize(self, im_name): """ get image size info Returns: ---------- tuple of (height, width) """ img = cv2.imread(im_name) return (img.shape[0], img.shape[1])