/ src / python / txtai / pipeline / image / objects.py
objects.py
 1  """
 2  Objects module
 3  """
 4  
 5  # Conditional import
 6  try:
 7      from PIL import Image
 8  
 9      PIL = True
10  except ImportError:
11      PIL = False
12  
13  from ..hfpipeline import HFPipeline
14  
15  
class Objects(HFPipeline):
    """
    Applies object detection models to images. Supports both object detection models and image classification models.
    """

    def __init__(self, path=None, quantize=False, gpu=True, model=None, classification=False, threshold=0.9, **kwargs):
        """
        Creates a new Objects pipeline.

        Args:
            path: passed to HFPipeline (presumably a model path or id)
            quantize: passed to HFPipeline
            gpu: passed to HFPipeline
            model: passed to HFPipeline
            classification: if True, runs an "image-classification" pipeline, otherwise an "object-detection" pipeline
            threshold: minimum score; only results with a score strictly greater than this value are returned
            kwargs: additional keyword arguments passed to HFPipeline

        Raises:
            ImportError: if PIL is not installed
        """

        if not PIL:
            raise ImportError('Objects pipeline is not available - install "pipeline" extra to enable')

        super().__init__("image-classification" if classification else "object-detection", path, quantize, gpu, model, **kwargs)

        self.classification = classification
        self.threshold = threshold

    def __call__(self, images, flatten=False, workers=0):
        """
        Applies object detection/image classification models to images. Returns a list of (label, score).

        This method supports a single image or a list of images. If the input is an image, the return
        type is a 1D list of (label, score). If images is a list, a 2D list of (label, score) is
        returned with a row per image.

        Images given as file path strings are opened with PIL and closed again once the pipeline has run.

        Args:
            images: image|list
            flatten: flatten output to a list of labels (scores are dropped)
            workers: number of concurrent workers to use for processing data, defaults to 0

        Returns:
            list of (label, score), or list of labels when flatten is True
        """

        # Convert single element to list
        values = images if isinstance(images, list) else [images]

        # Empty list input - nothing to run through the model
        if not values:
            return []

        # Images opened here from file paths, tracked so their file handles can be released
        opened = []
        try:
            # Open images if file strings
            inputs = []
            for image in values:
                if isinstance(image, str):
                    image = Image.open(image)
                    opened.append(image)
                inputs.append(image)

            # Run pipeline. Materialize results so every image has been consumed before closing files.
            results = list(
                self.pipeline(inputs, num_workers=workers)
                if self.classification
                else self.pipeline(inputs, threshold=self.threshold, num_workers=workers)
            )
        finally:
            # Only "label" and "score" values are read from results below, so opened images can be closed now
            for image in opened:
                image.close()

        # Build list of (label, score) per image
        outputs = []
        for result in results:
            # Convert to (label, score) tuples, keeping only scores above threshold
            ranked = [(x["label"], x["score"]) for x in result if x["score"] > self.threshold]

            # Sort by score descending (stable for ties)
            ranked = sorted(ranked, key=lambda x: x[1], reverse=True)

            # Deduplicate labels, keeping the first (highest scoring) occurrence
            unique = set()
            elements = []
            for label, score in ranked:
                if label not in unique:
                    elements.append(label if flatten else (label, score))
                    unique.add(label)

            outputs.append(elements)

        # Return single element if single element passed in
        return outputs if isinstance(images, list) else outputs[0]