# kitti_source.py
from builtins import object
import os
from image.box import Box
from image.frame import Frame
from image.object import Object
from os import listdir
from os.path import isfile, join, basename, splitext
from data_loading.sources.source import Source
from collections import defaultdict
from interface import Interface, implements
import numpy as np


class KITTILabel(object):
    """One parsed row of a KITTI label file.

    A row is a whitespace-separated record. Columns 0-4 are stored verbatim
    as strings (``frame_idx``, ``track_idx``, ``type``, ``truncated``,
    ``occluded``); column 5 is skipped; columns 6-9 are parsed as floats into
    ``bbox``.
    """

    @classmethod
    def labels_from_file(cls, filename):
        """Return a list with one ``KITTILabel`` per non-blank line of *filename*."""
        with open(filename, 'r') as f:
            # Blank lines (e.g. a trailing newline at end of file) carry no
            # label and would otherwise fail with IndexError in __init__.
            return [cls(line) for line in f if line.strip()]

    def __init__(self, line):
        """Parse a single label row.

        Raises:
            IndexError: if the row has fewer than five columns.
            ValueError: if a bounding-box column is not a valid float.
        """
        # split() with no argument tolerates repeated spaces, tabs and the
        # trailing newline; split(' ') would yield empty fields for those.
        fields = line.split()
        self.frame_idx = fields[0]
        self.track_idx = fields[1]
        self.type = fields[2]
        self.truncated = fields[3]
        self.occluded = fields[4]
        # NOTE(review): presumably left, top, right, bottom in pixels, as in
        # the KITTI devkit -- confirm against Box.from_single_array.
        self.bbox = [float(x) for x in fields[6:10]]

    def to_object(self):
        """Build an ``Object`` from this label's box, track index and type."""
        box_format = [self.bbox[0], self.bbox[1], self.bbox[2], self.bbox[3]]
        return Object(Box.from_single_array(box_format), self.track_idx, self.type)


class KITTISource(implements(Source)):
    """Frame source that reads a KITTI-style directory tree.

    Expected layout under ``dir_path``::

        image_02/<seq>/<frame>.<ext>   one directory of images per sequence
        label_02/<seq>.txt             labels for that sequence
        calib/<seq>.txt                calibration for that sequence (optional)

    ``<seq>`` and ``<frame>`` are purely numeric names. All frames are loaded
    eagerly in the constructor, capped at ``max_frames``.
    """

    image_dir = 'image_02'
    label_dir = 'label_02'
    calib_dir = 'calib'

    def __init__(self, dir_path, max_frames=float("inf")):
        """Load up to *max_frames* frames from the tree rooted at *dir_path*."""
        self.dir_path = dir_path
        self.frames = []
        self.max_frames = max_frames

        self.__load_frames(dir_path)
        self.size = len(self.frames)
        self.cur = 0

    def __load_camera_matrix_from_calib(self, calib_file, line_prefix='P2'):
        """Return the left 3x3 part of the 3x4 matrix stored on the *line_prefix* line.

        Only the first line of the form ``"<line_prefix>: v0 v1 ... v11"`` is
        used. Returns ``None`` when no such line exists.

        Raises:
            ValueError: if the matching line does not hold exactly 12 floats.
        """
        with open(calib_file, 'r') as f:
            for line in f:
                # partition() never raises on a line without ': ', unlike
                # indexing the result of split(': ').
                key, sep, values = line.rstrip().partition(': ')
                if sep and key == line_prefix:
                    mat = np.array([float(v) for v in values.split()], dtype=float)
                    return mat.reshape((3, 4))[0:3, 0:3]
        return None

    def __load_labelled_frames(self, frame_dir, labels_file, calib_file=None):
        """Build the ``Frame`` list for one sequence directory.

        Args:
            frame_dir: directory holding the numerically named image files.
            labels_file: label file for the sequence, or ``None`` for no labels.
            calib_file: calibration file; ignored when ``None`` or missing.

        Returns:
            Frames ordered by numeric file name, each carrying the objects
            whose label ``frame_idx`` equals that number and the shared
            calibration matrix (or ``None``).
        """
        calibration_mat = None
        # calib_file defaults to None and os.path.isfile(None) raises
        # TypeError on Python 3, so test for None first.
        if calib_file is not None and isfile(calib_file):
            calibration_mat = self.__load_camera_matrix_from_calib(calib_file)

        labels = KITTILabel.labels_from_file(labels_file) if labels_file is not None else []
        # Group labels by frame once instead of rescanning every label for
        # every image file.
        objects_by_frame = defaultdict(list)
        for label in labels:
            objects_by_frame[int(label.frame_idx)].append(label.to_object())

        indexed_files = []
        for f in listdir(frame_dir):
            stem = splitext(basename(f))[0]
            # Skip sub-directories and stray non-numeric files (e.g.
            # .DS_Store) that would break the integer sort key.
            if isfile(join(frame_dir, f)) and self.__validate_file_name(stem):
                indexed_files.append((int(stem), f))
        # Sort on the number only; the stable sort keeps listdir order for ties.
        indexed_files.sort(key=lambda pair: pair[0])

        return [
            Frame(join(frame_dir, f),
                  objs=objects_by_frame.get(file_index, []),
                  calib_mat=calibration_mat)
            for file_index, f in indexed_files
        ]

    # KITTI image files and sequence directories are numeric only,
    # i.e. 00001, 00002, etc.
    def __validate_file_name(self, file_name):
        """Return True when *file_name* consists solely of digits."""
        return file_name.isdigit()

    # Assumes image dirs are named image_xx and label files label_xx/<seq>.txt.
    def __load_frames(self, dir_path):
        """Append frames from every numeric sequence directory to ``self.frames``.

        Sequences are visited in ascending numeric order so that the result,
        and which frames survive the ``max_frames`` cap, does not depend on
        the platform's directory listing order.

        Raises:
            AssertionError: if the image or label directory does not exist.
        """
        imagedir_full = os.path.join(dir_path, KITTISource.image_dir)
        labeldir_full = os.path.join(dir_path, KITTISource.label_dir)
        calibdir_full = os.path.join(dir_path, KITTISource.calib_dir)
        assert os.path.exists(imagedir_full), "Cannot find image dir at {}".format(imagedir_full)
        assert os.path.exists(labeldir_full), "Cannot find label dir at {}".format(labeldir_full)
        # The calibration directory is optional: a missing calibration file
        # simply yields frames without a calibration matrix.

        sequences = sorted(
            (item for item in listdir(imagedir_full) if self.__validate_file_name(item)),
            key=int,
        )
        for item in sequences:
            # Check the cap before touching the disk for another sequence.
            if len(self.frames) >= self.max_frames:
                return
            label_path = os.path.join(labeldir_full, item + '.txt')
            calib_path = os.path.join(calibdir_full, item + '.txt')
            new_frames = self.__load_labelled_frames(
                os.path.join(imagedir_full, item), label_path, calib_path)
            remaining = self.max_frames - len(self.frames)
            if remaining < len(new_frames):
                # max_frames may be a float (the default is inf); slice
                # bounds must be integers.
                new_frames = new_frames[:int(remaining)]
            self.frames.extend(new_frames)

    def __next__(self):
        """Return the frame at the cursor and advance it; StopIteration at the end."""
        if self.cur >= len(self.frames):
            raise StopIteration
        ret = self.frames[self.cur]
        self.cur += 1
        return ret

    def __iter__(self):
        """Return self; iteration resumes from the current cursor (see ``reset``)."""
        return self

    def __len__(self):
        """Return the number of loaded frames."""
        return len(self.frames)

    def __getitem__(self, index):
        """Return the frame(s) at *index* without moving the cursor."""
        return self.frames[index]

    def reset(self):
        """Rewind the iteration cursor to the first frame."""
        self.cur = 0