/ pytlib / data_loading / sources / stanford_cars_source.py
stanford_cars_source.py
 1  from __future__ import print_function
 2  from interface import Interface, implements
 3  from image.box import Box
 4  from image.frame import Frame
 5  from image.object import Object
 6  from data_loading.sources.source import Source
 7  import scipy.io
 8  import os
 9  
10  # images can be downloaded here: http://imagenet.stanford.edu/internal/car196/cars_train.tgz
11  class StanfordCarsSource(implements(Source)):
12      def __init__(self,cars_dir,labels_mat):
13          self.frames = []
14          self.__load_frames(cars_dir,labels_mat)
15          self.cur = 0
16  
17      def __load_frames(self,cars_dir,labels_mat):
18          print('Loading Stanford Cars Frames')
19          labels = scipy.io.loadmat(labels_mat)['annotations'][0]
20          # load frames with labels
21          for label in labels:
22              if len(label)==5:
23                  xmin, ymin, xmax, ymax, path = label
24              elif len(label)==6:
25                  xmin, ymin, xmax, ymax, _, path = label
26              else:
27                  assert False, 'unable to parse label!'
28              box = Box(float(xmin[0][0]),float(ymin[0][0]),float(xmax[0][0]),float(ymax[0][0]))
29              obj = Object(box,obj_type='car')
30              image_path = os.path.join(cars_dir,path[0])
31              self.frames.append(Frame(image_path,[obj]))
32  
33      def __next__(self):
34          if self.cur >= len(self.frames):
35              raise StopIteration
36          else:
37              ret = self.frames[self.cur]
38              self.cur+=1
39              return ret
40  
41      def __iter__(self):
42          return self
43  
44      def __len__(self):
45          return len(self.frames)
46  
47      def __getitem__(self,index):
48          return self.frames[index]
49