stanford_cars_source.py
from __future__ import print_function

import os

import scipy.io
from interface import Interface, implements

from data_loading.sources.source import Source
from image.box import Box
from image.frame import Frame
from image.object import Object


# images can be downloaded here: http://imagenet.stanford.edu/internal/car196/cars_train.tgz
class StanfordCarsSource(implements(Source)):
    """Frame source built from a Stanford Cars annotation ``.mat`` file.

    All frames are created eagerly in the constructor and kept in
    ``self.frames``. The instance is its own iterator: ``__iter__`` returns
    ``self`` and the cursor ``self.cur`` is never reset, so once iteration is
    exhausted a second ``for`` loop over the same instance yields nothing.
    Use ``len()`` and indexing for repeatable access.
    """

    def __init__(self, cars_dir, labels_mat):
        """Load every annotated frame.

        Args:
            cars_dir: Directory that the image paths stored in the
                annotations are joined onto.
            labels_mat: Path of the ``.mat`` file passed to
                ``scipy.io.loadmat``; it must contain an ``'annotations'``
                entry.
        """
        self.frames = []
        self.__load_frames(cars_dir, labels_mat)
        # Cursor used by __next__; points at the next frame to hand out.
        self.cur = 0

    def __load_frames(self, cars_dir, labels_mat):
        """Parse the annotation records and append one Frame per record.

        Each record must have either 5 fields
        ``(xmin, ymin, xmax, ymax, path)`` or 6 fields, in which case the
        fifth field is ignored (presumably the class id -- not verified
        here).

        Raises:
            AssertionError: If a record has any other number of fields.
        """
        print('Loading Stanford Cars Frames')
        labels = scipy.io.loadmat(labels_mat)['annotations'][0]
        # load frames with labels
        for label in labels:
            if len(label) == 5:
                xmin, ymin, xmax, ymax, path = label
            elif len(label) == 6:
                xmin, ymin, xmax, ymax, _, path = label
            else:
                # Raised explicitly instead of `assert False`: assert
                # statements are removed under `python -O`, which would let a
                # malformed record fall through and reuse the previous
                # record's values (or hit a NameError on the first record).
                raise AssertionError('unable to parse label!')
            # Each coordinate field is indexed [0][0] to reach its scalar.
            box = Box(
                float(xmin[0][0]),
                float(ymin[0][0]),
                float(xmax[0][0]),
                float(ymax[0][0]),
            )
            obj = Object(box, obj_type='car')
            image_path = os.path.join(cars_dir, path[0])
            self.frames.append(Frame(image_path, [obj]))

    def __next__(self):
        """Return the frame at the cursor and advance it.

        Raises:
            StopIteration: When every frame has already been returned.
        """
        if self.cur >= len(self.frames):
            raise StopIteration
        ret = self.frames[self.cur]
        self.cur += 1
        return ret

    def __iter__(self):
        """Return ``self``; the cursor is shared and is not rewound."""
        return self

    def __len__(self):
        """Return the number of loaded frames."""
        return len(self.frames)

    def __getitem__(self, index):
        """Return ``self.frames[index]``, independent of the cursor."""
        return self.frames[index]