objects.py
"""
Objects module
"""

# Conditional import
try:
    from PIL import Image

    PIL = True
except ImportError:
    PIL = False

from ..hfpipeline import HFPipeline


class Objects(HFPipeline):
    """
    Applies object detection models to images. Supports both object detection models and image classification models.
    """

    def __init__(self, path=None, quantize=False, gpu=True, model=None, classification=False, threshold=0.9, **kwargs):
        """
        Creates an Objects pipeline.

        Args:
            path: model path, forwarded to HFPipeline
            quantize: forwarded to HFPipeline
            gpu: forwarded to HFPipeline
            model: forwarded to HFPipeline
            classification: if True, uses an "image-classification" task, otherwise "object-detection"
            threshold: predictions must score strictly above this value to be returned
            kwargs: additional arguments forwarded to HFPipeline

        Raises:
            ImportError: if PIL is not installed
        """

        if not PIL:
            raise ImportError('Objects pipeline is not available - install "pipeline" extra to enable')

        super().__init__("image-classification" if classification else "object-detection", path, quantize, gpu, model, **kwargs)

        self.classification = classification
        self.threshold = threshold

    def __call__(self, images, flatten=False, workers=0):
        """
        Applies object detection/image classification models to images. Returns a list of (label, score).

        This method supports a single image or a list of images. If the input is an image, the return
        type is a 1D list of (label, score). If images is a list, a 2D list of (label, score) is
        returned with a row per image.

        Args:
            images: image|list, each element either an image object or a file path string
            flatten: if True, return only labels instead of (label, score) tuples
            workers: number of concurrent workers to use for processing data, defaults to 0

        Returns:
            list of (label, score), or list of labels if flatten is True
        """

        # Convert single element to list
        values = [images] if not isinstance(images, list) else images

        # Open images given as file path strings. Images opened here are tracked so their
        # file handles can be released once the pipeline has finished with them.
        opened = []
        try:
            inputs = []
            for image in values:
                if isinstance(image, str):
                    image = Image.open(image)
                    opened.append(image)
                inputs.append(image)

            # Run pipeline. threshold is only passed to the object detection pipeline; scores are
            # filtered again in _postprocess so classification results honor it too.
            # Results are materialized into a list before the opened images are closed.
            results = list(
                self.pipeline(inputs, num_workers=workers)
                if self.classification
                else self.pipeline(inputs, threshold=self.threshold, num_workers=workers)
            )
        finally:
            for image in opened:
                image.close()

        # Build a list of (label, score) (or labels when flatten is True) per image
        outputs = [self._postprocess(result, flatten) for result in results]

        # Return single element if single element passed in
        return outputs[0] if not isinstance(images, list) else outputs

    def _postprocess(self, result, flatten):
        """
        Filters, sorts and deduplicates the predictions for a single image.

        Args:
            result: list of prediction dicts with "label" and "score" keys
            flatten: if True, return labels only

        Returns:
            list of (label, score) sorted by score descending with one entry per label,
            or the corresponding list of labels if flatten is True
        """

        # Convert to (label, score) tuples, keeping only scores above threshold
        scored = [(x["label"], x["score"]) for x in result if x["score"] > self.threshold]

        # Sort by score descending so the first occurrence of each label is its best score
        scored = sorted(scored, key=lambda x: x[1], reverse=True)

        # Deduplicate labels, keeping the highest scoring entry for each
        unique = set()
        elements = []
        for label, score in scored:
            if label not in unique:
                elements.append(label if flatten else (label, score))
                unique.add(label)

        return elements