/ spoolman / database / spool.py
spool.py
  1  """Helper functions for interacting with spool database objects."""
  2  
  3  import logging
  4  from collections.abc import Sequence
  5  from datetime import datetime, timezone
  6  
  7  import sqlalchemy
  8  from sqlalchemy import case, func
  9  from sqlalchemy.exc import NoResultFound
 10  from sqlalchemy.ext.asyncio import AsyncSession
 11  from sqlalchemy.orm import contains_eager, joinedload
 12  from sqlalchemy.sql.functions import coalesce
 13  
 14  from spoolman.api.v1.models import EventType, Spool, SpoolEvent
 15  from spoolman.database import filament, models
 16  from spoolman.database.utils import (
 17      SortOrder,
 18      add_where_clause_int,
 19      add_where_clause_int_opt,
 20      add_where_clause_str,
 21      add_where_clause_str_opt,
 22      parse_nested_field,
 23  )
 24  from spoolman.exceptions import ItemCreateError, ItemNotFoundError, SpoolMeasureError
 25  from spoolman.math import weight_from_length
 26  from spoolman.ws import websocket_manager
 27  
 28  logger = logging.getLogger(__name__)
 29  
 30  
 31  def utc_timezone_naive(dt: datetime) -> datetime:
 32      """Convert a datetime object to UTC and remove timezone info."""
 33      return dt.astimezone(tz=timezone.utc).replace(tzinfo=None)
 34  
 35  
 36  async def create(
 37      *,
 38      db: AsyncSession,
 39      filament_id: int,
 40      remaining_weight: float | None = None,
 41      initial_weight: float | None = None,
 42      spool_weight: float | None = None,
 43      used_weight: float | None = None,
 44      first_used: datetime | None = None,
 45      last_used: datetime | None = None,
 46      price: float | None = None,
 47      location: str | None = None,
 48      lot_nr: str | None = None,
 49      comment: str | None = None,
 50      archived: bool = False,
 51      extra: dict[str, str] | None = None,
 52  ) -> models.Spool:
 53      """Add a new spool to the database. Leave weight empty to assume full spool."""
 54      filament_item = await filament.get_by_id(db, filament_id)
 55  
 56      # Set spool_weight to spool_weight if spool_weight is not null and spool_weight not provided
 57      if spool_weight is None and filament_item.spool_weight is not None:
 58          spool_weight = filament_item.spool_weight
 59  
 60      # Calculate initial_weight if not provided
 61      if initial_weight is None and filament_item.weight is not None:
 62          initial_weight = filament_item.weight
 63  
 64      if used_weight is None:
 65          if remaining_weight is not None:
 66              if initial_weight is None or initial_weight == 0:
 67                  raise ItemCreateError(
 68                      "remaining_weight can only be used if the initial_weight is "
 69                      "defined or the filament has a weight set.",
 70                  )
 71              used_weight = max(initial_weight - remaining_weight, 0)
 72          else:
 73              used_weight = 0
 74  
 75      # Convert datetime values to UTC and remove timezone info
 76      if first_used is not None:
 77          first_used = utc_timezone_naive(first_used)
 78      if last_used is not None:
 79          last_used = utc_timezone_naive(last_used)
 80  
 81      spool = models.Spool(
 82          filament=filament_item,
 83          registered=datetime.utcnow().replace(microsecond=0),
 84          initial_weight=initial_weight,
 85          spool_weight=spool_weight,
 86          used_weight=used_weight,
 87          price=price,
 88          first_used=first_used,
 89          last_used=last_used,
 90          location=location,
 91          lot_nr=lot_nr,
 92          comment=comment,
 93          archived=archived,
 94          extra=[models.SpoolField(key=k, value=v) for k, v in (extra or {}).items()],
 95      )
 96      db.add(spool)
 97      await db.commit()
 98      await spool_changed(spool, EventType.ADDED)
 99      return spool
100  
101  
async def get_by_id(db: AsyncSession, spool_id: int) -> models.Spool:
    """Fetch a single spool by its primary key, with nested objects joined in.

    Args:
        db (AsyncSession): Database session
        spool_id (int): Spool ID

    Returns:
        models.Spool: The matching spool object

    Raises:
        ItemNotFoundError: If no spool with the given ID exists.

    """
    # The "*" wildcard asks for every nested object to be loaded along with the spool
    load_everything = [joinedload("*")]
    found = await db.get(models.Spool, spool_id, options=load_everything)
    if found is not None:
        return found
    raise ItemNotFoundError(f"No spool with ID {spool_id} found.")
112  
113  
async def find(  # noqa: C901, PLR0912
    *,
    db: AsyncSession,
    filament_name: str | None = None,
    filament_id: int | Sequence[int] | None = None,
    filament_material: str | None = None,
    vendor_name: str | None = None,
    vendor_id: int | Sequence[int] | None = None,
    location: str | None = None,
    lot_nr: str | None = None,
    allow_archived: bool = False,
    sort_by: dict[str, SortOrder] | None = None,
    limit: int | None = None,
    offset: int = 0,
) -> tuple[list[models.Spool], int]:
    """Search for spools matching the given criteria.

    ``sort_by`` maps a field name to a sort order. Field names may be nested
    (e.g. filament.name). A few names are handled specially: remaining_weight,
    remaining_length, used_length, filament.combined_name and price.

    When ``limit`` is given, the total number of matches is counted with a
    separate query before offset/limit are applied; otherwise the total is
    simply the number of returned items.

    Returns a tuple of (matching spools, total count of matching spools).
    """
    # Spool -> filament -> vendor, outer-joined, with the joined rows used to populate the relationships
    query = sqlalchemy.select(models.Spool)
    query = query.join(models.Spool.filament, isouter=True)
    query = query.join(models.Filament.vendor, isouter=True)
    query = query.options(contains_eager(models.Spool.filament).contains_eager(models.Filament.vendor))

    # Filters (the exact matching semantics live in the add_where_clause_* helpers)
    query = add_where_clause_int(query, models.Spool.filament_id, filament_id)
    query = add_where_clause_int_opt(query, models.Filament.vendor_id, vendor_id)
    query = add_where_clause_str(query, models.Vendor.name, vendor_name)
    query = add_where_clause_str_opt(query, models.Filament.name, filament_name)
    query = add_where_clause_str_opt(query, models.Filament.material, filament_material)
    query = add_where_clause_str_opt(query, models.Spool.location, location)
    query = add_where_clause_str_opt(query, models.Spool.lot_nr, lot_nr)

    if not allow_archived:
        # archived is nullable with a default of false, so both false and null mean "not archived"
        not_archived = sqlalchemy.or_(
            models.Spool.archived.is_(False),
            models.Spool.archived.is_(None),
        )
        query = query.where(not_archived)

    total_count = None
    if limit is not None:
        # Count all matches before the result window is applied
        count_query = query.with_only_columns(func.count(), maintain_column_froms=True)
        total_count = (await db.execute(count_query)).scalar()
        query = query.offset(offset).limit(limit)

    for field_name, direction in (sort_by or {}).items():
        if field_name == "remaining_weight":
            columns = [coalesce(models.Spool.initial_weight, models.Filament.weight) - models.Spool.used_weight]
        elif field_name == "remaining_length":
            # Simplified weight -> length conversion: the absolute value is wrong, but it is
            # proportional to the real length, so the resulting ordering is correct.
            columns = [
                (coalesce(models.Spool.initial_weight, models.Filament.weight) - models.Spool.used_weight)
                / models.Filament.density
                / (models.Filament.diameter * models.Filament.diameter),
            ]
        elif field_name == "used_length":
            columns = [
                models.Spool.used_weight
                / models.Filament.density
                / (models.Filament.diameter * models.Filament.diameter),
            ]
        elif field_name == "filament.combined_name":
            columns = [models.Vendor.name, models.Filament.name]
        elif field_name == "price":
            columns = [coalesce(models.Spool.price, models.Filament.price)]
        else:
            columns = [parse_nested_field(models.Spool, field_name)]

        if direction == SortOrder.ASC:
            query = query.order_by(*[column.asc() for column in columns])
        elif direction == SortOrder.DESC:
            query = query.order_by(*[column.desc() for column in columns])

    executed = await db.execute(query, execution_options={"populate_existing": True})
    spools = list(executed.unique().scalars().all())

    return spools, (len(spools) if total_count is None else total_count)
209  
210  
async def update(
    *,
    db: AsyncSession,
    spool_id: int,
    data: dict,
) -> models.Spool:
    """Apply a dict of field changes to an existing spool.

    Most keys are set directly as attributes on the spool. A few are special:
    ``filament_id`` swaps the filament, ``remaining_weight`` is translated into
    ``used_weight``, datetime values are stored as naive UTC, and ``extra``
    replaces the extra fields whose keys are present in the given dict.

    The change is committed and an UPDATED event is sent to websocket clients.

    Raises:
        ItemCreateError: If remaining_weight is given but the spool has no initial_weight.

    """
    spool = await get_by_id(db, spool_id)

    for field, value in data.items():
        if field == "filament_id":
            spool.filament = await filament.get_by_id(db, value)
            # Without an initial_weight on the spool, inherit the weight of the new filament
            if spool.initial_weight is None and spool.filament.weight is not None:
                spool.initial_weight = spool.filament.weight
            continue

        if field == "remaining_weight":
            if spool.initial_weight is None:
                raise ItemCreateError("remaining_weight can only be used if initial_weight is set.")
            spool.used_weight = max(spool.initial_weight - value, 0)
            continue

        if isinstance(value, datetime):
            setattr(spool, field, utc_timezone_naive(value))
            continue

        if field == "extra":
            # Drop the entries that are being overwritten, then append the new ones
            spool.extra = [entry for entry in spool.extra if entry.key not in value]
            spool.extra.extend(
                [models.SpoolField(key=extra_key, value=extra_value) for extra_key, extra_value in value.items()],
            )
            continue

        setattr(spool, field, value)

    await db.commit()
    await spool_changed(spool, EventType.UPDATED)
    return spool
240  
241  
async def delete(db: AsyncSession, spool_id: int) -> None:
    """Remove a spool, announcing the deletion to websocket clients first.

    No commit is issued here.
    NOTE(review): presumably the caller or session lifecycle commits - confirm.
    """
    doomed = await get_by_id(db, spool_id)
    # The DELETED event is sent before the object is marked for deletion
    await spool_changed(doomed, EventType.DELETED)
    await db.delete(doomed)
247  
248  
async def clear_extra_field(db: AsyncSession, key: str) -> None:
    """Remove every spool extra field stored under the given key."""
    matches_key = models.SpoolField.key == key
    purge = sqlalchemy.delete(models.SpoolField).where(matches_key)
    await db.execute(purge)
254  
255  
async def use_weight_safe(db: AsyncSession, spool_id: int, weight: float) -> None:
    """Add to a spool's used_weight with a single UPDATE statement.

    The new value is computed inside the UPDATE itself rather than being read
    and written back from Python, which is what makes it safe against race
    conditions. The result is floored at 0.

    Args:
        db (AsyncSession): Database session
        spool_id (int): Spool ID
        weight (float): Filament weight to consume, in grams

    """
    stays_non_negative = models.Spool.used_weight + weight >= 0.0
    increased = models.Spool.used_weight + weight
    # Use the increased value when it is non-negative, otherwise fall back to 0
    floored = case((stays_non_negative, increased), else_=0.0)

    consume = sqlalchemy.update(models.Spool).where(models.Spool.id == spool_id).values(used_weight=floored)
    await db.execute(consume)
275  
276  
async def use_weight(db: AsyncSession, spool_id: int, weight: float) -> models.Spool:
    """Consume filament from a spool by weight.

    Increases the used_weight attribute of the spool.
    Updates the first_used and last_used attributes where appropriate.
    The change is committed and an UPDATED event is sent to websocket clients.

    Args:
        db (AsyncSession): Database session
        spool_id (int): Spool ID
        weight (float): Filament weight to consume, in grams

    Returns:
        models.Spool: Updated spool object

    Raises:
        ItemNotFoundError: If no spool with the given ID exists.

    """
    await use_weight_safe(db, spool_id, weight)

    spool = await get_by_id(db, spool_id)

    # Take one timestamp so first_used and last_used can never differ by a second
    # when both are set here. datetime.utcnow() is deprecated since Python 3.12, so
    # build the same naive UTC value from an aware "now".
    now = datetime.now(tz=timezone.utc).replace(tzinfo=None, microsecond=0)
    if spool.first_used is None:
        spool.first_used = now
    spool.last_used = now

    await db.commit()
    await spool_changed(spool, EventType.UPDATED)
    return spool
303  
304  
async def use_length(db: AsyncSession, spool_id: int, length: float) -> models.Spool:
    """Consume filament from a spool by length.

    The length is converted to a weight using the diameter and density of the
    spool's filament, and the used_weight attribute of the spool is increased
    by that weight.
    Updates the first_used and last_used attributes where appropriate.
    The change is committed and an UPDATED event is sent to websocket clients.

    Args:
        db (AsyncSession): Database session
        spool_id (int): Spool ID
        length (float): Length of filament to consume, in mm

    Returns:
        models.Spool: Updated spool object

    Raises:
        ItemNotFoundError: If no filament row can be found for the spool.

    """
    # Get filament diameter and density
    result = await db.execute(
        sqlalchemy.select(models.Filament.diameter, models.Filament.density)
        .join(models.Spool, models.Spool.filament_id == models.Filament.id)
        .where(models.Spool.id == spool_id),
    )
    try:
        filament_info = result.one()
    except NoResultFound as exc:
        raise ItemNotFoundError("Filament not found for spool.") from exc

    # Calculate and use weight
    weight = weight_from_length(
        length=length,
        diameter=filament_info[0],
        density=filament_info[1],
    )
    await use_weight_safe(db, spool_id, weight)

    # Get spool with new weight and update first_used and last_used
    spool = await get_by_id(db, spool_id)

    # Take one timestamp so first_used and last_used can never differ by a second
    # when both are set here. datetime.utcnow() is deprecated since Python 3.12, so
    # build the same naive UTC value from an aware "now".
    now = datetime.now(tz=timezone.utc).replace(tzinfo=None, microsecond=0)
    if spool.first_used is None:
        spool.first_used = now
    spool.last_used = now

    await db.commit()
    await spool_changed(spool, EventType.UPDATED)
    return spool
349  
350  
async def measure(db: AsyncSession, spool_id: int, weight: float) -> models.Spool:
    """Record usage based on current gross weight of spool.

    The measured weight is compared against the gross weight the spool is
    expected to have (initial filament weight + empty spool weight - used weight)
    and the difference is consumed via use_weight.

    Missing (None or 0) initial_weight / spool_weight values on the spool are
    taken from the filament instead. If neither defines an empty spool weight,
    0 is assumed.

    If the measurement is heavier than the initial gross weight, the spool's
    initial weight is reset to the measurement (minus the empty spool weight)
    and its used weight to 0.

    Args:
        db (AsyncSession): Database session
        spool_id (int): Spool ID
        weight (float): Measured gross weight of the spool (filament + empty spool), in grams

    Returns:
        models.Spool: Updated spool object

    Raises:
        SpoolMeasureError: If the spool does not exist or no initial weight is known.
        ItemNotFoundError: If no filament row can be found for the spool.

    """
    spool_result = await db.execute(
        sqlalchemy.select(models.Spool.initial_weight, models.Spool.used_weight, models.Spool.spool_weight).where(
            models.Spool.id == spool_id,
        ),
    )

    try:
        spool_info = spool_result.one()
    except NoResultFound as exc:
        raise SpoolMeasureError("Spool not found.") from exc

    initial_weight = spool_info[0]
    used_weight = spool_info[1]
    spool_weight = spool_info[2]
    if initial_weight is None or initial_weight == 0 or spool_weight is None or spool_weight == 0:
        # Get filament weight and spool_weight
        result = await db.execute(
            sqlalchemy.select(models.Filament.weight, models.Filament.spool_weight)
            .join(models.Spool, models.Spool.filament_id == models.Filament.id)
            .where(models.Spool.id == spool_id),
        )
        try:
            filament_info = result.one()
        except NoResultFound as exc:
            raise ItemNotFoundError("Filament not found for spool.") from exc

        if spool_weight is None or spool_weight == 0:
            # Default to 0 when the filament has no spool_weight either, otherwise the
            # arithmetic below would fail on None.
            spool_weight = filament_info[1] if filament_info[1] is not None else 0

        if initial_weight is None or initial_weight == 0:
            initial_weight = filament_info[0] if filament_info[0] is not None else 0

    if initial_weight is None or initial_weight == 0:
        raise SpoolMeasureError("Initial weight is not set.")

    initial_gross_weight = initial_weight + spool_weight

    # if the measurement is greater than the initial weight, set the initial weight to the measurement
    if weight > initial_gross_weight:
        return await reset_initial_weight(db, spool_id, weight - spool_weight)

    # Gross weight the spool should currently have according to the recorded usage
    expected_gross_weight = initial_gross_weight - used_weight

    # Calculate the weight used since last measure
    weight_to_use = expected_gross_weight - weight

    # If the measured weight is less than the empty weight, use the rest of the spool.
    # Comparing the measurement itself against the empty spool weight ensures that
    # used_weight ends up at exactly initial_weight and never beyond it.
    if weight < spool_weight:
        weight_to_use = expected_gross_weight - spool_weight

    return await use_weight(db, spool_id, weight_to_use)
417  
418  
async def find_locations(
    *,
    db: AsyncSession,
) -> list[str]:
    """Return the distinct, non-null location values found in the spool table."""
    distinct_locations = sqlalchemy.select(models.Spool.location).distinct()
    fetched = (await db.execute(distinct_locations)).all()
    # Null locations are dropped here in Python rather than in the query
    return [record[0] for record in fetched if record[0] is not None]
427  
428  
async def find_lot_numbers(
    *,
    db: AsyncSession,
) -> list[str]:
    """Return the distinct, non-null lot numbers found in the spool table."""
    distinct_lot_numbers = sqlalchemy.select(models.Spool.lot_nr).distinct()
    fetched = (await db.execute(distinct_lot_numbers)).all()
    # Null lot numbers are dropped here in Python rather than in the query
    return [record[0] for record in fetched if record[0] is not None]
437  
438  
async def spool_changed(spool: models.Spool, typ: EventType) -> None:
    """Notify websocket clients that a spool has changed.

    Sending is best-effort: any failure is logged and swallowed so that it never
    aborts the database operation that triggered the notification.

    Args:
        spool (models.Spool): The spool the event is about
        typ (EventType): The kind of change that happened

    """
    try:
        await websocket_manager.send(
            ("spool", str(spool.id)),
            SpoolEvent(
                type=typ,
                resource="spool",
                # Naive UTC timestamp; datetime.utcnow() is deprecated since Python 3.12
                date=datetime.now(tz=timezone.utc).replace(tzinfo=None),
                payload=Spool.from_db(spool),
            ),
        )
    except Exception:
        # Important to have a catch-all here since we don't want to stop the call if this fails.
        logger.exception("Failed to send websocket message")
454  
455  
async def reset_initial_weight(db: AsyncSession, spool_id: int, weight: float) -> models.Spool:
    """Set a spool's initial weight to ``weight`` and zero out its used weight.

    The change is committed and an UPDATED event is sent to websocket clients.

    Returns:
        models.Spool: Updated spool object

    """
    target = await get_by_id(db, spool_id)
    target.initial_weight, target.used_weight = weight, 0

    await db.commit()
    await spool_changed(target, EventType.UPDATED)
    return target
465  
466  
async def rename_location(
    *,
    db: AsyncSession,
    current_name: str,
    new_name: str,
) -> None:
    """Move every spool located at ``current_name`` to ``new_name`` with one UPDATE."""
    relocate = (
        sqlalchemy.update(models.Spool)
        .where(models.Spool.location == current_name)
        .values(location=new_name)
    )
    await db.execute(relocate)