spool.py
"""Helper functions for interacting with spool database objects."""

import logging
from collections.abc import Sequence
from datetime import datetime, timezone

import sqlalchemy
from sqlalchemy import case, func
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import contains_eager, joinedload
from sqlalchemy.sql.functions import coalesce

from spoolman.api.v1.models import EventType, Spool, SpoolEvent
from spoolman.database import filament, models
from spoolman.database.utils import (
    SortOrder,
    add_where_clause_int,
    add_where_clause_int_opt,
    add_where_clause_str,
    add_where_clause_str_opt,
    parse_nested_field,
)
from spoolman.exceptions import ItemCreateError, ItemNotFoundError, SpoolMeasureError
from spoolman.math import weight_from_length
from spoolman.ws import websocket_manager

logger = logging.getLogger(__name__)


def utc_timezone_naive(dt: datetime) -> datetime:
    """Convert a datetime object to UTC and remove timezone info."""
    return dt.astimezone(tz=timezone.utc).replace(tzinfo=None)


def _utc_now_naive() -> datetime:
    """Return the current UTC time as a naive datetime.

    Equivalent to the deprecated ``datetime.utcnow()``. Timestamps handled by this module are stored as naive UTC
    (see ``utc_timezone_naive``).
    """
    return datetime.now(tz=timezone.utc).replace(tzinfo=None)


async def create(
    *,
    db: AsyncSession,
    filament_id: int,
    remaining_weight: float | None = None,
    initial_weight: float | None = None,
    spool_weight: float | None = None,
    used_weight: float | None = None,
    first_used: datetime | None = None,
    last_used: datetime | None = None,
    price: float | None = None,
    location: str | None = None,
    lot_nr: str | None = None,
    comment: str | None = None,
    archived: bool = False,
    extra: dict[str, str] | None = None,
) -> models.Spool:
    """Add a new spool to the database.

    Leave weight empty to assume full spool.

    Defaults:
        spool_weight and initial_weight fall back to the filament's spool_weight and weight when not given.
        used_weight takes precedence over remaining_weight; if neither is given, the spool is assumed unused.

    Raises:
        ItemCreateError: If remaining_weight is given but no non-zero initial weight is known.
    """
    filament_item = await filament.get_by_id(db, filament_id)

    # Use the filament's spool_weight if the caller did not provide one
    if spool_weight is None and filament_item.spool_weight is not None:
        spool_weight = filament_item.spool_weight

    # Use the filament's weight as initial_weight if the caller did not provide one
    if initial_weight is None and filament_item.weight is not None:
        initial_weight = filament_item.weight

    if used_weight is None:
        if remaining_weight is not None:
            if initial_weight is None or initial_weight == 0:
                raise ItemCreateError(
                    "remaining_weight can only be used if the initial_weight is "
                    "defined or the filament has a weight set.",
                )
            used_weight = max(initial_weight - remaining_weight, 0)
        else:
            used_weight = 0

    # Convert datetime values to UTC and remove timezone info
    if first_used is not None:
        first_used = utc_timezone_naive(first_used)
    if last_used is not None:
        last_used = utc_timezone_naive(last_used)

    spool = models.Spool(
        filament=filament_item,
        registered=_utc_now_naive().replace(microsecond=0),
        initial_weight=initial_weight,
        spool_weight=spool_weight,
        used_weight=used_weight,
        price=price,
        first_used=first_used,
        last_used=last_used,
        location=location,
        lot_nr=lot_nr,
        comment=comment,
        archived=archived,
        extra=[models.SpoolField(key=k, value=v) for k, v in (extra or {}).items()],
    )
    db.add(spool)
    await db.commit()
    await spool_changed(spool, EventType.ADDED)
    return spool


async def get_by_id(db: AsyncSession, spool_id: int) -> models.Spool:
    """Get a spool object from the database by the unique ID.

    Raises:
        ItemNotFoundError: If no spool with the given ID exists.
    """
    spool = await db.get(
        models.Spool,
        spool_id,
        options=[joinedload("*")],  # Load all nested objects as well
    )
    if spool is None:
        raise ItemNotFoundError(f"No spool with ID {spool_id} found.")
    return spool


def _sort_expressions(field: str) -> list:
    """Map an API sort field name (possibly nested, e.g. "filament.name") to the SQL expressions to order by."""
    if field == "remaining_weight":
        return [coalesce(models.Spool.initial_weight, models.Filament.weight) - models.Spool.used_weight]
    if field == "remaining_length":
        # Simplified weight -> length formula. Absolute value is not correct but the proportionality is still
        # kept, which means the sort order is correct.
        return [
            (coalesce(models.Spool.initial_weight, models.Filament.weight) - models.Spool.used_weight)
            / models.Filament.density
            / (models.Filament.diameter * models.Filament.diameter),
        ]
    if field == "used_length":
        return [
            models.Spool.used_weight
            / models.Filament.density
            / (models.Filament.diameter * models.Filament.diameter),
        ]
    if field == "filament.combined_name":
        # Combined name is "<vendor> - <filament>", so sort by vendor first, then filament name
        return [models.Vendor.name, models.Filament.name]
    if field == "price":
        # A spool without its own price inherits the filament price
        return [coalesce(models.Spool.price, models.Filament.price)]
    return [parse_nested_field(models.Spool, field)]


async def find(
    *,
    db: AsyncSession,
    filament_name: str | None = None,
    filament_id: int | Sequence[int] | None = None,
    filament_material: str | None = None,
    vendor_name: str | None = None,
    vendor_id: int | Sequence[int] | None = None,
    location: str | None = None,
    lot_nr: str | None = None,
    allow_archived: bool = False,
    sort_by: dict[str, SortOrder] | None = None,
    limit: int | None = None,
    offset: int = 0,
) -> tuple[list[models.Spool], int]:
    """Find a list of spool objects by search criteria.

    Sort by a field by passing a dict with the field name as key and the sort order as value.
    The field name can contain nested fields, e.g. filament.name.

    Note that offset is only applied together with limit.

    Returns a tuple containing the list of items and the total count of matching items.
    """
    stmt = (
        sqlalchemy.select(models.Spool)
        .join(models.Spool.filament, isouter=True)
        .join(models.Filament.vendor, isouter=True)
        .options(contains_eager(models.Spool.filament).contains_eager(models.Filament.vendor))
    )

    stmt = add_where_clause_int(stmt, models.Spool.filament_id, filament_id)
    stmt = add_where_clause_int_opt(stmt, models.Filament.vendor_id, vendor_id)
    stmt = add_where_clause_str(stmt, models.Vendor.name, vendor_name)
    stmt = add_where_clause_str_opt(stmt, models.Filament.name, filament_name)
    stmt = add_where_clause_str_opt(stmt, models.Filament.material, filament_material)
    stmt = add_where_clause_str_opt(stmt, models.Spool.location, location)
    stmt = add_where_clause_str_opt(stmt, models.Spool.lot_nr, lot_nr)

    if not allow_archived:
        # Since the archived field is nullable, and default is false, we need to check for both false or null
        stmt = stmt.where(
            sqlalchemy.or_(
                models.Spool.archived.is_(False),
                models.Spool.archived.is_(None),
            ),
        )

    total_count = None

    if limit is not None:
        # Count all matches before pagination so callers can page through the full result set
        total_count_stmt = stmt.with_only_columns(func.count(), maintain_column_froms=True)
        total_count = (await db.execute(total_count_stmt)).scalar()

        stmt = stmt.offset(offset).limit(limit)

    if sort_by is not None:
        for field_name, order in sort_by.items():
            sorts = _sort_expressions(field_name)
            if order == SortOrder.ASC:
                stmt = stmt.order_by(*(expr.asc() for expr in sorts))
            elif order == SortOrder.DESC:
                stmt = stmt.order_by(*(expr.desc() for expr in sorts))

    rows = await db.execute(
        stmt,
        execution_options={"populate_existing": True},
    )
    result = list(rows.unique().scalars().all())
    if total_count is None:
        total_count = len(result)

    return result, total_count


async def update(
    *,
    db: AsyncSession,
    spool_id: int,
    data: dict,
) -> models.Spool:
    """Update the fields of a spool object.

    Keys in ``data`` are spool attribute names, with these special cases:

    - ``filament_id``: switches the spool to another filament. If the spool has no initial_weight,
      the new filament's weight is used.
    - ``remaining_weight``: converted into used_weight relative to initial_weight. It is applied after
      all other keys, so an initial_weight or filament_id in the same update is taken into account.
    - ``extra``: merged into the existing extra fields. Given keys are replaced, others are kept.
    - datetime values are converted to naive UTC.

    Raises:
        ItemCreateError: If remaining_weight is given but the spool has no initial_weight.
    """
    spool = await get_by_id(db, spool_id)
    for key, value in data.items():
        if key == "remaining_weight":
            continue  # Depends on initial_weight, so it is applied after every other field below
        if key == "filament_id":
            spool.filament = await filament.get_by_id(db, value)
            # If there is no initial_weight, calculate it from the filament weight
            if spool.initial_weight is None and spool.filament.weight is not None:
                spool.initial_weight = spool.filament.weight
        elif isinstance(value, datetime):
            setattr(spool, key, utc_timezone_naive(value))
        elif key == "extra":
            spool.extra = [field for field in spool.extra if field.key not in value]
            spool.extra.extend([models.SpoolField(key=extra_key, value=extra_value) for extra_key, extra_value in value.items()])
        else:
            setattr(spool, key, value)

    if "remaining_weight" in data:
        if spool.initial_weight is None:
            raise ItemCreateError("remaining_weight can only be used if initial_weight is set.")
        spool.used_weight = max(spool.initial_weight - data["remaining_weight"], 0)

    await db.commit()
    await spool_changed(spool, EventType.UPDATED)
    return spool


async def delete(db: AsyncSession, spool_id: int) -> None:
    """Delete a spool object.

    Listeners are notified before the row is deleted, while the spool object is still fully loaded.
    No commit is issued here.
    NOTE(review): presumably the caller's session handling commits — confirm.
    """
    spool = await get_by_id(db, spool_id)
    await spool_changed(spool, EventType.DELETED)
    await db.delete(spool)


async def clear_extra_field(db: AsyncSession, key: str) -> None:
    """Delete all extra fields with a specific key."""
    await db.execute(
        sqlalchemy.delete(models.SpoolField).where(models.SpoolField.key == key),
    )


async def use_weight_safe(db: AsyncSession, spool_id: int, weight: float) -> None:
    """Consume filament from a spool by weight in a way that is safe against race conditions.

    The increment is computed by the database in a single UPDATE statement, so concurrent
    consumers do not overwrite each other. used_weight is clamped at 0 if the result would be negative.

    Args:
        db (AsyncSession): Database session
        spool_id (int): Spool ID
        weight (float): Filament weight to consume, in grams

    """
    await db.execute(
        sqlalchemy.update(models.Spool)
        .where(models.Spool.id == spool_id)
        .values(
            used_weight=case(
                (models.Spool.used_weight + weight >= 0.0, models.Spool.used_weight + weight),
                else_=0.0,  # Set used_weight to 0 if the result would be negative
            ),
        ),
    )


async def _stamp_usage(db: AsyncSession, spool_id: int) -> models.Spool:
    """Reload a spool after consumption, update first_used/last_used, commit and notify listeners."""
    spool = await get_by_id(db, spool_id)

    # Use one timestamp so first_used and last_used agree when this is the first use
    now = _utc_now_naive().replace(microsecond=0)
    if spool.first_used is None:
        spool.first_used = now
    spool.last_used = now

    await db.commit()
    await spool_changed(spool, EventType.UPDATED)
    return spool


async def use_weight(db: AsyncSession, spool_id: int, weight: float) -> models.Spool:
    """Consume filament from a spool by weight.

    Increases the used_weight attribute of the spool.
    Updates the first_used and last_used attributes where appropriate.

    Args:
        db (AsyncSession): Database session
        spool_id (int): Spool ID
        weight (float): Filament weight to consume, in grams

    Returns:
        models.Spool: Updated spool object

    """
    await use_weight_safe(db, spool_id, weight)
    return await _stamp_usage(db, spool_id)


async def use_length(db: AsyncSession, spool_id: int, length: float) -> models.Spool:
    """Consume filament from a spool by length.

    Increases the used_weight attribute of the spool.
    Updates the first_used and last_used attributes where appropriate.

    Args:
        db (AsyncSession): Database session
        spool_id (int): Spool ID
        length (float): Length of filament to consume, in mm

    Returns:
        models.Spool: Updated spool object

    Raises:
        ItemNotFoundError: If the spool or its filament cannot be found.

    """
    # Get filament diameter and density
    result = await db.execute(
        sqlalchemy.select(models.Filament.diameter, models.Filament.density)
        .join(models.Spool, models.Spool.filament_id == models.Filament.id)
        .where(models.Spool.id == spool_id),
    )
    try:
        diameter, density = result.one()
    except NoResultFound as exc:
        raise ItemNotFoundError("Filament not found for spool.") from exc

    # Calculate and use weight
    weight = weight_from_length(
        length=length,
        diameter=diameter,
        density=density,
    )
    await use_weight_safe(db, spool_id, weight)

    # Get spool with new weight and update first_used and last_used
    return await _stamp_usage(db, spool_id)


async def measure(db: AsyncSession, spool_id: int, weight: float) -> models.Spool:
    """Record usage based on the current gross weight of a spool.

    The gross weight is the weight of the remaining filament plus the empty spool. The expected gross weight
    comes from the stored initial_weight, spool_weight and used_weight. The difference between it and the
    measured weight is consumed via use_weight, which also updates first_used and last_used. A measurement
    below the empty spool weight consumes the rest of the spool.

    Missing (None or 0) initial_weight / spool_weight values fall back to the filament's weight / spool_weight.
    A spool weight that is still unknown after that is treated as 0.

    If the measurement exceeds the initial gross weight, the spool is instead reset via reset_initial_weight,
    making the measured net weight the new initial_weight.

    Args:
        db (AsyncSession): Database session
        spool_id (int): Spool ID
        weight (float): Measured gross weight of the spool (filament plus empty spool), in grams

    Returns:
        models.Spool: Updated spool object

    Raises:
        SpoolMeasureError: If the spool does not exist or no initial weight is known.
        ItemNotFoundError: If the spool's filament cannot be found.

    """
    spool_result = await db.execute(
        sqlalchemy.select(models.Spool.initial_weight, models.Spool.used_weight, models.Spool.spool_weight).where(
            models.Spool.id == spool_id,
        ),
    )

    try:
        initial_weight, used_weight, spool_weight = spool_result.one()
    except NoResultFound as exc:
        raise SpoolMeasureError("Spool not found.") from exc

    # None and 0 both count as "not set" for these weights
    if not initial_weight or not spool_weight:
        # Get filament weight and spool_weight
        result = await db.execute(
            sqlalchemy.select(models.Filament.weight, models.Filament.spool_weight)
            .join(models.Spool, models.Spool.filament_id == models.Filament.id)
            .where(models.Spool.id == spool_id),
        )
        try:
            filament_weight, filament_spool_weight = result.one()
        except NoResultFound as exc:
            raise ItemNotFoundError("Filament not found for spool.") from exc

        if not spool_weight:
            # An unknown empty-spool weight is treated as 0 instead of failing on None arithmetic below
            spool_weight = filament_spool_weight or 0

        if not initial_weight:
            initial_weight = filament_weight or 0

    if not initial_weight:
        raise SpoolMeasureError("Initial weight is not set.")

    initial_gross_weight = initial_weight + spool_weight

    # If the measurement is greater than the initial weight, set the initial weight to the measurement
    if weight > initial_gross_weight:
        return await reset_initial_weight(db, spool_id, weight - spool_weight)

    # Gross weight the spool should currently have according to the database
    expected_gross_weight = initial_gross_weight - used_weight

    # If the measured weight is less than the empty weight, use the rest of the spool (never more than remains)
    measured_gross_weight = max(weight, spool_weight)

    # Weight used since the last measurement
    weight_to_use = expected_gross_weight - measured_gross_weight

    return await use_weight(db, spool_id, weight_to_use)


async def find_locations(
    *,
    db: AsyncSession,
) -> list[str]:
    """Find a list of spool locations by searching for distinct values in the spool table."""
    stmt = sqlalchemy.select(models.Spool.location).distinct()
    rows = await db.execute(stmt)
    return [row[0] for row in rows.all() if row[0] is not None]


async def find_lot_numbers(
    *,
    db: AsyncSession,
) -> list[str]:
    """Find a list of spool lot numbers by searching for distinct values in the spool table."""
    stmt = sqlalchemy.select(models.Spool.lot_nr).distinct()
    rows = await db.execute(stmt)
    return [row[0] for row in rows.all() if row[0] is not None]


async def spool_changed(spool: models.Spool, typ: EventType) -> None:
    """Notify websocket clients that a spool has changed.

    Failures are logged and swallowed so that a notification problem never aborts the calling operation.
    """
    try:
        await websocket_manager.send(
            ("spool", str(spool.id)),
            SpoolEvent(
                type=typ,
                resource="spool",
                date=_utc_now_naive(),
                payload=Spool.from_db(spool),
            ),
        )
    except Exception:
        # Important to have a catch-all here since we don't want to stop the call if this fails.
        logger.exception("Failed to send websocket message")


async def reset_initial_weight(db: AsyncSession, spool_id: int, weight: float) -> models.Spool:
    """Reset initial weight to the new weight and used_weight to 0."""
    spool = await get_by_id(db, spool_id)

    spool.initial_weight = weight
    spool.used_weight = 0
    await db.commit()
    await spool_changed(spool, EventType.UPDATED)
    return spool


async def rename_location(
    *,
    db: AsyncSession,
    current_name: str,
    new_name: str,
) -> None:
    """Rename all spools with the current location name to the new name."""
    await db.execute(
        sqlalchemy.update(models.Spool).where(models.Spool.location == current_name).values(location=new_name),
    )