# Lobe_Cat_Detector/lobe-cat-detector.py
# SPDX-FileCopyrightText: 2021 Melissa LeBlanc-Williams for Adafruit Industries
#
# SPDX-License-Identifier: MIT
  4  
  5  import time
  6  from enum import Enum, auto
  7  import board
  8  from digitalio import DigitalInOut, Direction, Pull
  9  import picamera
 10  import io
 11  from PIL import Image
 12  from lobe import ImageModel
 13  import os
 14  import adafruit_dotstar
 15  from datetime import datetime
 16  import pwmio
 17  from adafruit_motor import servo
 18  
# Prediction labels from the Lobe model that main() compares against.
LABEL_CAT = "Cat"
LABEL_MULTI_CAT = "Multiple Cats"
LABEL_NOTHING = "No Cats"
# GPIO pin whose PWM output drives the key-jingling servo.
SERVO_PIN = board.D12
# jingle_keys() switches to the hard jingle once it has been called more than
# this many times since jingle_count was last reset.
WARNING_COUNT = 3

# 50 Hz PWM on SERVO_PIN, starting with a 0 duty cycle (no pulses output yet).
pwm = pwmio.PWMOut(SERVO_PIN, duty_cycle=0, frequency=50)
# NOTE: this rebinds the name ``servo`` from the adafruit_motor.servo module to
# the Servo instance; the module is no longer reachable under that name.
# min_pulse/max_pulse set the pulse-width range handed to adafruit_motor
# (presumably microseconds per that library's API) — tune for the actual servo.
servo = servo.Servo(pwm, min_pulse=400, max_pulse=2400)
 27  
 28  # Boiler Plate code for buttons and joystick on the braincraft
 29  BUTTON_PIN = board.D17
 30  JOYDOWN_PIN = board.D27
 31  JOYLEFT_PIN = board.D22
 32  JOYUP_PIN = board.D23
 33  JOYRIGHT_PIN = board.D24
 34  JOYSELECT_PIN = board.D16
 35  
 36  buttons = [BUTTON_PIN, JOYUP_PIN, JOYDOWN_PIN,
 37             JOYLEFT_PIN, JOYRIGHT_PIN, JOYSELECT_PIN]
 38  
 39  for i, pin in enumerate(buttons):
 40      buttons[i] = DigitalInOut(pin)
 41      buttons[i].direction = Direction.INPUT
 42      buttons[i].pull = Pull.UP
 43  button, joyup, joydown, joyleft, joyright, joyselect = buttons
 44  
 45  
 46  class Input(Enum):
 47      BUTTON = auto()
 48      UP = auto()
 49      DOWN = auto()
 50      LEFT = auto()
 51      RIGHT = auto()
 52      SELECT = auto()
 53  
 54  
 55  def get_inputs():
 56      inputs = []
 57      if not button.value:
 58          inputs.append(Input.BUTTON)
 59      if not joyup.value:
 60          inputs.append(Input.UP)
 61      if not joydown.value:
 62          inputs.append(Input.DOWN)
 63      if not joyleft.value:
 64          inputs.append(Input.LEFT)
 65      if not joyright.value:
 66          inputs.append(Input.RIGHT)
 67      if not joyselect.value:
 68          inputs.append(Input.SELECT)
 69      return inputs
 70  
# DotStar LED pins: data on D5, clock on D6.
DOTSTAR_DATA = board.D5
DOTSTAR_CLOCK = board.D6

# Colour tuples passed to color_fill().
# NOTE(review): the channel order looks swapped relative to the names (GREEN
# sets the first byte, RED the last); presumably this compensates for the LEDs'
# byte order — confirm on hardware before "fixing".
RED = (0, 0, 255)
GREEN = (255, 0, 0)
OFF = (0, 0, 0)

# Three-pixel DotStar strip at 10% brightness; main() lights it while saving frames.
dots = adafruit_dotstar.DotStar(DOTSTAR_CLOCK, DOTSTAR_DATA, 3, brightness=0.1)

# Number of jingle_keys() calls since main() last reset it (on a "No Cats"
# prediction); compared against WARNING_COUNT to escalate the jingle.
jingle_count = 0
 81  
 82  def color_fill(color, wait):
 83      dots.fill(color)
 84      dots.show()
 85      time.sleep(wait)
 86  
 87  def jingle_keys(jingle_hard=False):
 88      global jingle_count
 89      jingle_count += 1
 90      if jingle_count > WARNING_COUNT:
 91          jingle_hard = True
 92      delay = 0.5 if jingle_hard else 2
 93      loop = 5 if jingle_hard else 1
 94      travel = 180 if jingle_hard else 135
 95      for _ in range(0, loop):
 96          for angle in (0, travel):
 97              servo.angle = angle
 98              time.sleep(delay)
 99      servo.angle = None
100  
def _save_frame(camera, directory):
    """Capture a still from *camera* into *directory* as a timestamped JPEG.

    The directory is created if it does not exist. The camera's text
    annotation is cleared first so the label is not drawn onto the saved frame.
    """
    os.makedirs(directory, exist_ok=True)
    camera.annotate_text = None
    filename = f'{datetime.now().strftime("%Y-%m-%d_%H:%M:%S")}.jpg'
    camera.capture(os.path.join(directory, filename))


def main():
    """Classify camera frames in a loop and react to cats until interrupted.

    Loads the exported Lobe model from ``~/model``. It then repeatedly does
    the following:

    * captures a 224x224 JPEG and predicts its label;
    * shows the label on the camera preview and the console;
    * jingles the keys when one cat (soft) or several cats (hard) are seen;
    * resets the warning count when no cats are seen.

    Holding the joystick up saves the current frame under
    ``./retraining_data/<label>/``. Holding the joystick down, or pressing the
    button, saves it directly under ``./retraining_data/``. The loop never
    returns on its own; the ``__main__`` guard catches ``KeyboardInterrupt``.
    """
    global jingle_count
    model = ImageModel.load('~/model')

    # Root folder for frames saved for retraining; create it if missing.
    os.makedirs('./retraining_data', exist_ok=True)

    with picamera.PiCamera(resolution=(224, 224), framerate=30) as camera:
        stream = io.BytesIO()
        camera.start_preview()
        # Camera warm-up time
        time.sleep(2)
        label = ''
        while True:
            # Reuse one buffer: rewind AND drop the previous frame's bytes so a
            # shorter JPEG is never followed by stale data from a longer one.
            stream.seek(0)
            stream.truncate()
            camera.annotate_text = None
            camera.capture(stream, format='jpeg')
            # Keep the previous label on the preview while this frame is classified.
            camera.annotate_text = label
            img = Image.open(stream)
            result = model.predict(img)
            label = result.prediction
            confidence = result.labels[0][1]
            camera.annotate_text = label
            print(f'\rLabel: {label} | Confidence: {confidence*100: .2f}%', end='', flush=True)

            # React to the prediction: soft jingle for one cat, hard jingle for
            # several, and reset the warning count once no cats are visible.
            if label == LABEL_CAT:
                jingle_keys()
            elif label == LABEL_MULTI_CAT:
                jingle_keys(True)
            elif label == LABEL_NOTHING:
                jingle_count = 0

            time.sleep(0.5)

            inputs = get_inputs()
            if Input.UP in inputs:
                # Joystick up: file the frame under the predicted label.
                color_fill(GREEN, 0)
                _save_frame(camera, f'./retraining_data/{label}')
                color_fill(OFF, 0)
            elif Input.DOWN in inputs or Input.BUTTON in inputs:
                # Joystick down or button: save to the top-level retraining folder.
                color_fill(RED, 0)
                _save_frame(camera, './retraining_data')
                color_fill(OFF, 0)
173  
174  
if __name__ == '__main__':
    try:
        print('Predictions starting, to stop press "CTRL+C"')
        main()
    except KeyboardInterrupt:
        # main() keeps rewriting one status line ('\r', no newline), so finish
        # that line before reporting the interrupt.
        print()
        print('Caught interrupt, exiting...')