/ adafruit_wave.py
adafruit_wave.py
# SPDX-FileCopyrightText: 2023 Guido van Rossum <guido@cwi.nl> and others.
#
# SPDX-License-Identifier: PSF-2.0
"""
`adafruit_wave`
================================================================================

Read and write standard WAV-format files


* Author(s): Jeff Epler

Implementation Notes
--------------------

**Software and Dependencies:**

* Adafruit CircuitPython firmware for the supported boards:
  https://circuitpython.org/downloads

"""

# pylint: disable=missing-class-docstring,redefined-outer-name,missing-function-docstring,invalid-name,import-outside-toplevel,too-many-instance-attributes,consider-using-with,no-self-use,redefined-builtin,not-callable,unused-variable,attribute-defined-outside-init,too-many-public-methods,no-else-return
# imports

__version__ = "0.0.0+auto.0"
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_wave.git"

from collections import namedtuple
import builtins
import struct


__all__ = ["open", "Error", "Wave_read", "Wave_write"]


class Chunk:
    """Reader for a single chunk of an IFF/RIFF-style container.

    A chunk is a 4-byte ID, a 4-byte length, then ``length`` bytes of data.
    Reads are confined to the chunk's data, and `skip` leaves the underlying
    file positioned at the start of the following chunk.

    :param file: object providing ``read``; when it also provides a working
        ``tell`` the chunk is seekable (``seek`` is then used as well)
    :param bool align: when true, an odd-sized chunk is followed by one pad
        byte, which is consumed together with the data
    :param bool bigendian: byte order of the length field
    :param bool inclheader: when true, the stored length includes the 8-byte
        header and is reduced by 8
    :raises EOFError: if the 8-byte header cannot be read completely
    """

    def __init__(self, file, align=True, bigendian=True, inclheader=False):
        self.closed = False
        self.align = align  # whether to align to word (2-byte) boundaries
        strflag = ">" if bigendian else "<"
        self.file = file
        self.chunkname = file.read(4)
        if len(self.chunkname) < 4:
            raise EOFError
        try:
            self.chunksize = struct.unpack_from(strflag + "L", file.read(4))[0]
        except struct.error:
            # Fewer than 4 bytes were available for the length field.
            raise EOFError from None
        if inclheader:
            self.chunksize = self.chunksize - 8  # subtract header
        self.size_read = 0
        try:
            # Absolute file position of the chunk data; used by seek().
            self.offset = self.file.tell()
        except (AttributeError, OSError):
            self.seekable = False
        else:
            self.seekable = True

    def getname(self):
        """Return the name (ID) of the current chunk."""
        return self.chunkname

    def getsize(self):
        """Return the size of the current chunk."""
        return self.chunksize

    def close(self):
        """Skip to the end of the chunk and mark it closed (idempotent)."""
        if not self.closed:
            try:
                self.skip()
            finally:
                self.closed = True

    def isatty(self):
        """Return False; raise ValueError if the chunk is closed."""
        if self.closed:
            raise ValueError("I/O operation on closed file")
        return False

    def seek(self, pos, whence=0):
        """Seek to specified position into the chunk.
        Default position is 0 (start of chunk).
        If the file is not seekable, this will result in an error.

        *whence* is 0 (from the start of the data), 1 (relative to the
        current position) or 2 (relative to the end of the data).

        :raises ValueError: if the chunk is closed
        :raises OSError: if the underlying file is not seekable
        :raises RuntimeError: if the target lies outside the chunk data
        """
        if self.closed:
            raise ValueError("I/O operation on closed file")
        if not self.seekable:
            raise OSError("cannot seek")
        if whence == 1:
            pos = pos + self.size_read
        elif whence == 2:
            pos = pos + self.chunksize
        if pos < 0 or pos > self.chunksize:
            raise RuntimeError
        self.file.seek(self.offset + pos, 0)
        self.size_read = pos

    def tell(self):
        """Return the current position within the chunk data."""
        if self.closed:
            raise ValueError("I/O operation on closed file")
        return self.size_read

    def read(self, size=-1):
        """Read at most size bytes from the chunk.

        If size is omitted or negative, read until the end
        of the chunk.
        """
        if self.closed:
            raise ValueError("I/O operation on closed file")
        if self.size_read >= self.chunksize:
            return b""
        remaining = self.chunksize - self.size_read
        if size < 0 or size > remaining:
            size = remaining
        data = self.file.read(size)
        self.size_read = self.size_read + len(data)
        # After the last data byte of an odd-sized chunk, also consume the pad
        # byte so the file ends up at the next chunk header.
        if self.size_read == self.chunksize and self.align and (self.chunksize & 1):
            dummy = self.file.read(1)
            self.size_read = self.size_read + len(dummy)
        return data

    def skip(self):
        """Skip the rest of the chunk.

        If you are not interested in the contents of the chunk,
        this method should be called so that the file points to
        the start of the next chunk.
        """
        if self.closed:
            raise ValueError("I/O operation on closed file")
        if self.seekable:
            try:
                n = self.chunksize - self.size_read
                # maybe fix alignment
                if self.align and (self.chunksize & 1):
                    n = n + 1
                self.file.seek(n, 1)
                self.size_read = self.size_read + n
                return
            except OSError:
                # Fall back to reading through the remaining data.
                pass
        while self.size_read < self.chunksize:
            n = min(8192, self.chunksize - self.size_read)
            dummy = self.read(n)
            if not dummy:
                raise EOFError


class Error(Exception):
    """Raised when a WAV file is malformed or the API is used incorrectly."""


WAVE_FORMAT_PCM = 0x0001

_array_fmts = None, "b", "h", None, "i"

_wave_params = namedtuple(
    "_wave_params", "nchannels sampwidth framerate nframes comptype compname"
)


class Wave_read:
    """Used for wave files opened in read mode.

    Do not construct directly, but call `open` instead."""

    def initfp(self, file):
        """Parse the RIFF/WAVE headers of *file* up to the start of the data.

        Chunks other than ``fmt `` and ``data`` are skipped.

        :raises Error: if the file is not RIFF/WAVE, the ``data`` chunk comes
            before ``fmt ``, either chunk is missing, or the format is not PCM
        :raises EOFError: if the RIFF header or ``fmt `` chunk is truncated
        """
        self._convert = None
        self._soundpos = 0
        self._file = Chunk(file, bigendian=False)
        if self._file.getname() != b"RIFF":
            raise Error("file does not start with RIFF id")
        if self._file.read(4) != b"WAVE":
            raise Error("not a WAVE file")
        self._fmt_chunk_read = False
        self._data_chunk = None
        while True:
            self._data_seek_needed = True
            try:
                chunk = Chunk(self._file, bigendian=False)
            except EOFError:
                break  # no more sub-chunks inside the RIFF chunk
            chunkname = chunk.getname()
            if chunkname == b"fmt ":
                self._read_fmt_chunk(chunk)
                self._fmt_chunk_read = True
            elif chunkname == b"data":
                if not self._fmt_chunk_read:
                    raise Error("data chunk before fmt chunk")
                self._data_chunk = chunk
                self._nframes = chunk.chunksize // self._framesize
                # The file is already positioned at the first sample.
                self._data_seek_needed = False
                break
            chunk.skip()
        if not self._fmt_chunk_read or self._data_chunk is None:
            raise Error("fmt chunk and/or data chunk missing")

    def __init__(self, f):
        """Open *f* (a filename or an open binary file) for reading.

        A file opened here from a filename is closed by `close`; a file
        object passed in by the caller is left open.
        """
        self._i_opened_the_file = None
        self._file = None
        if isinstance(f, str):
            f = builtins.open(f, "rb")
            self._i_opened_the_file = f
        # else, assume it is an open file object already
        try:
            self.initfp(f)
        except BaseException:
            # Don't leak a file we opened ourselves; __del__ will then find
            # nothing left to close.
            if self._i_opened_the_file:
                self._i_opened_the_file = None
                f.close()
            raise

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    #
    # User visible methods.
    #
    def getfp(self):
        """Get the underlying file object"""
        return self._file

    def rewind(self):
        """Go back to the start of the audio data"""
        self._data_seek_needed = True
        self._soundpos = 0

    def close(self):
        """Close the file"""
        self._file = None
        file = self._i_opened_the_file
        if file:
            self._i_opened_the_file = None
            file.close()

    def tell(self):
        """Get the current position in the audio data"""
        return self._soundpos

    def getnchannels(self):
        """Get the number of channels (1 for mono, 2 for stereo)"""
        return self._nchannels

    def getnframes(self):
        """Get the number of frames"""
        return self._nframes

    def getsampwidth(self):
        """Get the sample width in bytes"""
        return self._sampwidth

    def getframerate(self):
        """Get the sample rate in Hz"""
        return self._framerate

    def getcomptype(self):
        """Get the compression type (``"NONE"``; only PCM is supported)"""
        return self._comptype

    def getcompname(self):
        """Get the human-readable compression name"""
        return self._compname

    def getparams(self):
        """Get all parameters at once.

        :return: namedtuple ``(nchannels, sampwidth, framerate, nframes,
            comptype, compname)``, which `Wave_write.setparams` accepts
        """
        return _wave_params(
            self._nchannels,
            self._sampwidth,
            self._framerate,
            self._nframes,
            self._comptype,
            self._compname,
        )

    def setpos(self, pos):
        """Seek to a particular position in the audio data"""
        if pos < 0 or pos > self._nframes:
            raise Error("position not in range")
        self._soundpos = pos
        # The actual seek is deferred until the next readframes() call.
        self._data_seek_needed = True

    def readframes(self, nframes):
        """Read up to *nframes* frames of audio data and return them as bytes"""
        if self._data_seek_needed:
            self._data_chunk.seek(0, 0)
            pos = self._soundpos * self._framesize
            if pos:
                self._data_chunk.seek(pos, 0)
            self._data_seek_needed = False
        if nframes == 0:
            return b""
        data = self._data_chunk.read(nframes * self._framesize)
        if self._convert and data:
            data = self._convert(data)
        self._soundpos = self._soundpos + len(data) // (
            self._nchannels * self._sampwidth
        )
        return data

    #
    # Internal methods.
    #

    def _read_fmt_chunk(self, chunk):
        """Parse a ``fmt `` chunk and store the format attributes on self.

        :raises EOFError: if the chunk is too short
        :raises Error: for non-PCM formats, a zero sample width, or zero channels
        """
        try:
            (
                wFormatTag,
                self._nchannels,
                self._framerate,
                dwAvgBytesPerSec,
                wBlockAlign,
            ) = struct.unpack_from("<HHLLH", chunk.read(14))
        except struct.error:
            raise EOFError from None
        if wFormatTag == WAVE_FORMAT_PCM:
            try:
                sampwidth = struct.unpack_from("<H", chunk.read(2))[0]
            except struct.error:
                raise EOFError from None
            # The file stores bits per sample; round up to whole bytes.
            self._sampwidth = (sampwidth + 7) // 8
            if not self._sampwidth:
                raise Error("bad sample width")
        else:
            raise Error("unknown format: %r" % (wFormatTag,))
        if not self._nchannels:
            raise Error("bad # of channels")
        self._framesize = self._nchannels * self._sampwidth
        self._comptype = "NONE"
        self._compname = "not compressed"


class Wave_write:
    """Used for wave files opened in write mode.

    Do not construct directly, but call `open` instead."""

    def __init__(self, f):
        """Open *f* (a filename or an open binary file) for writing.

        A file opened here from a filename is closed by `close`; a file
        object passed in by the caller is flushed but left open.
        """
        self._i_opened_the_file = None
        # Set before anything can fail, so that __del__ -> close() never
        # touches a missing attribute if builtins.open() raises.
        self._file = None
        if isinstance(f, str):
            f = builtins.open(f, "wb")
            self._i_opened_the_file = f
        try:
            self.initfp(f)
        except BaseException:
            self._file = None
            if self._i_opened_the_file:
                self._i_opened_the_file = None
                f.close()
            raise

    def initfp(self, file):
        """Reset all writer state to write to *file*."""
        self._file = file
        self._convert = None
        self._nchannels = 0
        self._sampwidth = 0
        self._framerate = 0
        self._nframes = 0
        self._nframeswritten = 0
        self._datawritten = 0
        self._datalength = 0
        self._headerwritten = False

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    #
    # User visible methods.
    #
    def setnchannels(self, nchannels):
        """Set the number of channels (1 for mono, 2 for stereo)"""
        if self._datawritten:
            raise Error("cannot change parameters after starting to write")
        if nchannels < 1:
            raise Error("bad # of channels")
        self._nchannels = nchannels

    def getnchannels(self):
        """Get the number of channels (1 for mono, 2 for stereo)"""
        if not self._nchannels:
            raise Error("number of channels not set")
        return self._nchannels

    def setsampwidth(self, sampwidth):
        """Set the sample width in bytes"""
        if self._datawritten:
            raise Error("cannot change parameters after starting to write")
        if sampwidth < 1 or sampwidth > 4:
            raise Error("bad sample width")
        self._sampwidth = sampwidth

    def getsampwidth(self):
        """Get the sample width in bytes"""
        if not self._sampwidth:
            raise Error("sample width not set")
        return self._sampwidth

    def setframerate(self, framerate):
        """Set the sample rate in Hz"""
        if self._datawritten:
            raise Error("cannot change parameters after starting to write")
        if framerate <= 0:
            raise Error("bad frame rate")
        self._framerate = int(round(framerate))

    def getframerate(self):
        """Get the sample rate in Hz"""
        if not self._framerate:
            raise Error("frame rate not set")
        return self._framerate

    def setnframes(self, nframes):
        """Set the expected number of frames (the header is patched if wrong)"""
        if self._datawritten:
            raise Error("cannot change parameters after starting to write")
        self._nframes = nframes

    def getnframes(self):
        """Get the number of frames written so far"""
        return self._nframeswritten

    def getcomptype(self):
        """Get the compression type (always ``"NONE"``; only PCM is written)"""
        return "NONE"

    def getcompname(self):
        """Get the human-readable compression name"""
        return "not compressed"

    def setparams(self, params):
        """Set all parameters at once.

        *params* is a 6-sequence ``(nchannels, sampwidth, framerate, nframes,
        comptype, compname)``, e.g. from `Wave_read.getparams`; comptype and
        compname are accepted but ignored.
        """
        nchannels, sampwidth, framerate, nframes, comptype, compname = params
        if self._datawritten:
            raise Error("cannot change parameters after starting to write")
        self.setnchannels(nchannels)
        self.setsampwidth(sampwidth)
        self.setframerate(framerate)
        self.setnframes(nframes)

    def getparams(self):
        """Get all parameters at once as a namedtuple.

        :raises Error: if channels, sample width or frame rate are not set
        """
        if not self._nchannels or not self._sampwidth or not self._framerate:
            raise Error("not all parameters set")
        return _wave_params(
            self._nchannels,
            self._sampwidth,
            self._framerate,
            self._nframes,
            "NONE",
            "not compressed",
        )

    def tell(self):
        """Get the current position in the audio data"""
        return self._nframeswritten

    def writeframesraw(self, data):
        """Write data to the file without updating the header"""
        if not isinstance(data, (bytes, bytearray)):
            data = memoryview(data).cast("B")
        self._ensure_header_written(len(data))
        nframes = len(data) // (self._sampwidth * self._nchannels)
        if self._convert:
            data = self._convert(data)
        self._file.write(data)
        self._datawritten += len(data)
        self._nframeswritten = self._nframeswritten + nframes

    def writeframes(self, data):
        """Write data to the file and update the header if needed"""
        self.writeframesraw(data)
        if self._datalength != self._datawritten:
            self._patchheader()

    def close(self):
        """Close the file.

        Writes the header if nothing has been written yet and patches the
        length fields if they disagree with the amount of data written.
        """
        try:
            if self._file is not None:
                self._ensure_header_written(0)
                if self._datalength != self._datawritten:
                    self._patchheader()
                self._file.flush()
        finally:
            self._file = None
            file = self._i_opened_the_file
            if file:
                self._i_opened_the_file = None
                file.close()

    #
    # Internal methods.
    #

    def _ensure_header_written(self, datasize):
        """Write the header once, after checking the required parameters."""
        if not self._headerwritten:
            if not self._nchannels:
                raise Error("# channels not specified")
            if not self._sampwidth:
                raise Error("sample width not specified")
            if not self._framerate:
                raise Error("sampling rate not specified")
            self._write_header(datasize)

    def _write_header(self, initlength):
        """Write the 44-byte RIFF/WAVE/fmt/data header.

        The length fields come from the configured frame count, or from
        *initlength* bytes if no frame count was set.
        """
        assert not self._headerwritten
        self._file.write(b"RIFF")
        if not self._nframes:
            self._nframes = initlength // (self._nchannels * self._sampwidth)
        self._datalength = self._nframes * self._nchannels * self._sampwidth
        try:
            self._form_length_pos = self._file.tell()
        except (AttributeError, OSError):
            # NOTE(review): without tell() the header cannot be patched later.
            self._form_length_pos = None
        self._file.write(
            struct.pack(
                "<L4s4sLHHLLHH4s",
                36 + self._datalength,  # RIFF size: everything after this field
                b"WAVE",
                b"fmt ",
                16,  # fmt chunk size for plain PCM
                WAVE_FORMAT_PCM,
                self._nchannels,
                self._framerate,
                self._nchannels * self._framerate * self._sampwidth,  # bytes/s
                self._nchannels * self._sampwidth,  # block align
                self._sampwidth * 8,  # bits per sample
                b"data",
            )
        )
        if self._form_length_pos is not None:
            self._data_length_pos = self._file.tell()
        self._file.write(struct.pack("<L", self._datalength))
        self._headerwritten = True

    def _patchheader(self):
        """Rewrite both length fields to match the bytes actually written.

        Requires a seekable file; the write position is restored afterwards.
        """
        assert self._headerwritten
        if self._datawritten == self._datalength:
            return
        curpos = self._file.tell()
        self._file.seek(self._form_length_pos, 0)
        self._file.write(struct.pack("<L", 36 + self._datawritten))
        self._file.seek(self._data_length_pos, 0)
        self._file.write(struct.pack("<L", self._datawritten))
        self._file.seek(curpos, 0)
        self._datalength = self._datawritten


def open(f, mode=None):  # pylint: disable=redefined-builtin
    """Open a wave file in reading (default) or writing (``mode="w"``) mode.

    The argument may be a filename or an open file.  When *mode* is omitted,
    the file object's own ``mode`` attribute is used if it has one, otherwise
    reading is assumed.

    In reading mode, returns a `Wave_read` object.
    In writing mode, returns a `Wave_write` object.

    :raises Error: if *mode* is not one of ``r``, ``rb``, ``w``, ``wb``
    """
    if mode is None:
        if hasattr(f, "mode"):
            mode = f.mode
        else:
            mode = "rb"
    if mode in ("r", "rb"):
        return Wave_read(f)
    elif mode in ("w", "wb"):
        return Wave_write(f)
    else:
        raise Error("mode must be 'r', 'rb', 'w', or 'wb'")