/ pytlib / data_loading / sources / kitti_source.py
kitti_source.py
  1  from builtins import object
  2  import os
  3  from image.box import Box
  4  from image.frame import Frame
  5  from image.object import Object
  6  from os import listdir
  7  from os.path import isfile, join, basename, splitext
  8  from data_loading.sources.source import Source
  9  from collections import defaultdict
 10  from interface import Interface, implements
 11  import numpy as np
 12  
 13  class KITTILabel(object):
 14      @classmethod
 15      def labels_from_file(cls,filename):
 16          labels = []
 17          with open(filename,'r') as f:
 18              for line in f:
 19                  labels.append(cls(line))
 20          return labels
 21      
 22      def __init__(self,line):
 23          linearr = line.split(' ')
 24          self.frame_idx = linearr[0]
 25          self.track_idx = linearr[1]
 26          self.type = linearr[2]
 27          self.truncated = linearr[3]
 28          self.occluded = linearr[4]
 29          self.bbox = [float(x) for x in linearr[6:10]]
 30          
 31      def to_object(self):
 32          box_format = [self.bbox[0],self.bbox[1],self.bbox[2],self.bbox[3]]
 33          return Object(Box.from_single_array(box_format),self.track_idx,self.type)
 34  
 35  class KITTISource(implements(Source)):
 36  
 37      image_dir = 'image_02'
 38      label_dir = 'label_02'
 39      calib_dir = 'calib'
 40  
 41      def __init__(self,dir_path,max_frames=float("inf")):
 42          self.dir_path = dir_path
 43          self.frames = []
 44          self.max_frames = max_frames
 45  
 46          self.__load_frames(dir_path)
 47          self.size = len(self.frames)
 48          self.cur = 0
 49  
 50      def __load_camera_matrix_from_calib(self,calib_file,line_prefix='P2'):
 51          # only load the p_rect calibration
 52          with open(calib_file, 'r') as f:
 53              lines = f.readlines()
 54              lines = [line.rstrip() for line in lines]
 55              mat = None
 56              for line in lines:
 57                  nline = line.split(': ')
 58                  if nline[0]==line_prefix:
 59                      mat = nline[1].split(' ')
 60                      mat = np.array([float(r) for r in mat], dtype=float)
 61                      mat = mat.reshape((3,4))[0:3, 0:3]
 62                      break
 63          return mat
 64  
 65  
 66      def __load_labelled_frames(self,frame_dir,labels_file,calib_file=None):
 67          files = [f for f in listdir(frame_dir) if isfile(join(frame_dir, f))]
 68          labels = KITTILabel.labels_from_file(labels_file) if labels_file is not None else []
 69          calibration_mat = None
 70          if os.path.isfile(calib_file):
 71              calibration_mat = self.__load_camera_matrix_from_calib(calib_file)
 72          frames = []
 73          sorted_files = sorted(files, key=lambda x: int(splitext(basename(x))[0]))
 74          for f in sorted_files:
 75              file_index = int(splitext(basename(f))[0])
 76              objects = []
 77              for l in labels:
 78                  if int(l.frame_idx) == file_index:
 79                      objects.append(l.to_object())
 80              frames.append(Frame(join(frame_dir,f),objs=objects,calib_mat=calibration_mat))
 81          return frames
 82  
 83      # KITTI images files are numerical only, ie: 00001, 00002 etc...
 84      def __validate_file_name(self,file_name):
 85          return file_name.isdigit()
 86  
 87      # assumes image dirs are of type image_xx and labels are label_xx.txt
 88      def __load_frames(self,dir_path):
 89          imagedir_full = os.path.join(dir_path,KITTISource.image_dir)
 90          labeldir_full = os.path.join(dir_path,KITTISource.label_dir)
 91          calibdir_full = os.path.join(dir_path,KITTISource.calib_dir)
 92          assert os.path.exists(imagedir_full), "Cannot find image dir at {}".format(imagedir_full)
 93          assert os.path.exists(labeldir_full), "Cannot find image dir at {}".format(labeldir_full)
 94          # assert os.path.exists(calibdir_full), "Cannot find image dir at {}".format(calibdir_full)
 95          for item in listdir(imagedir_full):
 96              if self.__validate_file_name(item):
 97                  label_path = os.path.join(labeldir_full,item+'.txt')
 98                  calib_path = os.path.join(calibdir_full,item+'.txt')
 99                  new_frames = self.__load_labelled_frames(os.path.join(imagedir_full,item),label_path,calib_path)
100                  if len(self.frames) >= self.max_frames:
101                      return
102                  self.frames.extend(new_frames[0:min(len(new_frames),self.max_frames - len(self.frames))])
103  
104  
105      def __next__(self):
106          if self.cur >= len(self.frames):
107              raise StopIteration
108          else:
109              ret = self.frames[self.cur]
110              self.cur+=1
111              return ret
112  
113      def __iter__(self):
114          return self
115  
116      def __len__(self):
117          return len(self.frames)
118  
119      def __getitem__(self,index):
120          return self.frames[index]
121  
122      def reset(self):
123          self.cur = 0