/ adafruit_wave.py
adafruit_wave.py
  1  # SPDX-FileCopyrightText: 2023 Guido van Rossum <guido@cwi.nl> and others.
  2  #
  3  # SPDX-License-Identifier: PSF-2.0
  4  """
`adafruit_wave`
================================================================================

Read and write standard WAV-format files.

Only uncompressed PCM audio (``WAVE_FORMAT_PCM``) is supported.

* Author(s): Jeff Epler

Implementation Notes
--------------------

**Software and Dependencies:**

* Adafruit CircuitPython firmware for the supported boards:
  https://circuitpython.org/downloads

 21  """
 22  
 23  # pylint: disable=missing-class-docstring,redefined-outer-name,missing-function-docstring,invalid-name,import-outside-toplevel,too-many-instance-attributes,consider-using-with,no-self-use,redefined-builtin,not-callable,unused-variable,attribute-defined-outside-init,too-many-public-methods,no-else-return
 24  # imports
 25  
# NOTE(review): "0.0.0+auto.0" looks like a placeholder that release tooling
# substitutes with the real version string -- confirm before reformatting
# this assignment.
__version__ = "0.0.0+auto.0"
# Upstream source repository of this library.
__repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_wave.git"
 28  
 29  from collections import namedtuple
 30  import builtins
 31  import struct
 32  
 33  
 34  __all__ = ["open", "Error", "Wave_read", "Wave_write"]
 35  
 36  
 37  class Chunk:
 38      def __init__(self, file, align=True, bigendian=True, inclheader=False):
 39          self.closed = False
 40          self.align = align  # whether to align to word (2-byte) boundaries
 41          if bigendian:
 42              strflag = ">"
 43          else:
 44              strflag = "<"
 45          self.file = file
 46          self.chunkname = file.read(4)
 47          if len(self.chunkname) < 4:
 48              raise EOFError
 49          try:
 50              self.chunksize = struct.unpack_from(strflag + "L", file.read(4))[0]
 51          except struct.error:
 52              raise EOFError from None
 53          if inclheader:
 54              self.chunksize = self.chunksize - 8  # subtract header
 55          self.size_read = 0
 56          try:
 57              self.offset = self.file.tell()
 58          except (AttributeError, OSError):
 59              self.seekable = False
 60          else:
 61              self.seekable = True
 62  
 63      def getname(self):
 64          """Return the name (ID) of the current chunk."""
 65          return self.chunkname
 66  
 67      def getsize(self):
 68          """Return the size of the current chunk."""
 69          return self.chunksize
 70  
 71      def close(self):
 72          if not self.closed:
 73              try:
 74                  self.skip()
 75              finally:
 76                  self.closed = True
 77  
 78      def isatty(self):
 79          if self.closed:
 80              raise ValueError("I/O operation on closed file")
 81          return False
 82  
 83      def seek(self, pos, whence=0):
 84          """Seek to specified position into the chunk.
 85          Default position is 0 (start of chunk).
 86          If the file is not seekable, this will result in an error.
 87          """
 88  
 89          if self.closed:
 90              raise ValueError("I/O operation on closed file")
 91          if not self.seekable:
 92              raise OSError("cannot seek")
 93          if whence == 1:
 94              pos = pos + self.size_read
 95          elif whence == 2:
 96              pos = pos + self.chunksize
 97          if pos < 0 or pos > self.chunksize:
 98              raise RuntimeError
 99          self.file.seek(self.offset + pos, 0)
100          self.size_read = pos
101  
102      def tell(self):
103          if self.closed:
104              raise ValueError("I/O operation on closed file")
105          return self.size_read
106  
107      def read(self, size=-1):
108          """Read at most size bytes from the chunk.
109  
110          If size is omitted or negative, read until the end
111          of the chunk.
112          """
113  
114          if self.closed:
115              raise ValueError("I/O operation on closed file")
116          if self.size_read >= self.chunksize:
117              return b""
118          if size < 0:
119              size = self.chunksize - self.size_read
120          if size > self.chunksize - self.size_read:
121              size = self.chunksize - self.size_read
122          data = self.file.read(size)
123          self.size_read = self.size_read + len(data)
124          if self.size_read == self.chunksize and self.align and (self.chunksize & 1):
125              dummy = self.file.read(1)
126              self.size_read = self.size_read + len(dummy)
127          return data
128  
129      def skip(self):
130          """Skip the rest of the chunk.
131  
132          If you are not interested in the contents of the chunk,
133          this method should be called so that the file points to
134          the start of the next chunk.
135          """
136  
137          if self.closed:
138              raise ValueError("I/O operation on closed file")
139          if self.seekable:
140              try:
141                  n = self.chunksize - self.size_read
142                  # maybe fix alignment
143                  if self.align and (self.chunksize & 1):
144                      n = n + 1
145                  self.file.seek(n, 1)
146                  self.size_read = self.size_read + n
147                  return
148              except OSError:
149                  pass
150          while self.size_read < self.chunksize:
151              n = min(8192, self.chunksize - self.size_read)
152              dummy = self.read(n)
153              if not dummy:
154                  raise EOFError
155  
156  
class Error(Exception):
    """Raised for malformed WAV data and for invalid reader/writer usage."""
159  
160  
# Format tag of uncompressed PCM audio in the ``fmt `` chunk.
WAVE_FORMAT_PCM = 0x0001

# NOTE(review): presumably `array` typecodes indexed by sample width in
# bytes (None = no matching typecode); not referenced elsewhere in this module.
_array_fmts = (None, "b", "h", None, "i")

# Same field order as the 6-tuple unpacked by ``Wave_write.setparams``.
_wave_params = namedtuple(
    "_wave_params",
    ["nchannels", "sampwidth", "framerate", "nframes", "comptype", "compname"],
)
168  
169  
class Wave_read:
    """Used for wave files opened in read mode.

    Do not construct directly, but call `open` instead."""

    def initfp(self, file):
        """Parse the RIFF/WAVE header of *file* up to the start of the samples.

        Chunks other than ``fmt `` that precede the ``data`` chunk are skipped.

        :raises Error: if the file is not RIFF/WAVE, if ``data`` comes before
            ``fmt ``, or if either chunk is missing.
        """
        self._convert = None
        self._soundpos = 0
        self._file = Chunk(file, bigendian=0)
        if self._file.getname() != b"RIFF":
            raise Error("file does not start with RIFF id")
        if self._file.read(4) != b"WAVE":
            raise Error("not a WAVE file")
        self._fmt_chunk_read = False
        self._data_chunk = None
        self._data_seek_needed = True
        while True:
            try:
                chunk = Chunk(self._file, bigendian=0)
            except EOFError:
                break  # no more sub-chunks
            chunkname = chunk.getname()
            if chunkname == b"fmt ":
                self._read_fmt_chunk(chunk)
                self._fmt_chunk_read = True
            elif chunkname == b"data":
                if not self._fmt_chunk_read:
                    raise Error("data chunk before fmt chunk")
                self._data_chunk = chunk
                self._nframes = chunk.chunksize // self._framesize
                # The file is already positioned at the first sample.
                self._data_seek_needed = False
                break
            chunk.skip()
        if not self._fmt_chunk_read or self._data_chunk is None:
            raise Error("fmt chunk and/or data chunk missing")

    def __init__(self, f):
        """*f* is a filename (opened here in binary mode and closed again by
        `close`) or an already open binary file object."""
        self._i_opened_the_file = None
        if isinstance(f, str):
            f = builtins.open(f, "rb")
            self._i_opened_the_file = f
        # else, assume it is an open file object already
        try:
            self.initfp(f)
        except BaseException:
            # Do not leak a file we opened ourselves when parsing fails.
            if self._i_opened_the_file:
                f.close()
            raise

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    #
    # User visible methods.
    #
    def getfp(self):
        """Get the RIFF chunk reader wrapping the file (None after `close`)"""
        return self._file

    def rewind(self):
        """Go back to the start of the audio data"""
        self._data_seek_needed = True
        self._soundpos = 0

    def close(self):
        """Close the file.

        The underlying file is closed only if it was opened from a filename
        by this object."""
        self._file = None
        file = self._i_opened_the_file
        if file:
            self._i_opened_the_file = None
            file.close()

    def tell(self):
        """Get the current position in the audio data"""
        return self._soundpos

    def getnchannels(self):
        """Get the number of channels (1 for mono, 2 for stereo)"""
        return self._nchannels

    def getnframes(self):
        """Get the number of frames"""
        return self._nframes

    def getsampwidth(self):
        """Get the sample width in bytes"""
        return self._sampwidth

    def getframerate(self):
        """Get the sample rate in Hz"""
        return self._framerate

    def getcomptype(self):
        """Get the compression type (always ``"NONE"``: only PCM is read)"""
        return self._comptype

    def getcompname(self):
        """Get the human-readable compression name"""
        return self._compname

    def getparams(self):
        """Get all parameters at once.

        Returns a namedtuple ``(nchannels, sampwidth, framerate, nframes,
        comptype, compname)``, the same layout `Wave_write.setparams` accepts.
        """
        return _wave_params(
            self._nchannels,
            self._sampwidth,
            self._framerate,
            self._nframes,
            self._comptype,
            self._compname,
        )

    def setpos(self, pos):
        """Seek to a particular position in the audio data"""
        if pos < 0 or pos > self._nframes:
            raise Error("position not in range")
        self._soundpos = pos
        self._data_seek_needed = True

    def readframes(self, nframes):
        """Read up to *nframes* frames of audio data and return them as bytes.

        Fewer frames are returned when the end of the data chunk is reached.
        """
        if self._data_seek_needed:
            # Reposition inside the data chunk after rewind()/setpos().
            self._data_chunk.seek(0, 0)
            pos = self._soundpos * self._framesize
            if pos:
                self._data_chunk.seek(pos, 0)
            self._data_seek_needed = False
        if nframes == 0:
            return b""
        data = self._data_chunk.read(nframes * self._framesize)
        if self._convert and data:
            data = self._convert(data)
        self._soundpos = self._soundpos + len(data) // self._framesize
        return data

    #
    # Internal methods.
    #

    def _read_fmt_chunk(self, chunk):
        """Parse a ``fmt `` chunk and store the stream parameters.

        :raises EOFError: if the chunk is too short.
        :raises Error: for non-PCM data, zero sample width or zero channels.
        """
        try:
            (
                wFormatTag,
                self._nchannels,
                self._framerate,
                dwAvgBytesPerSec,
                wBlockAlign,
            ) = struct.unpack_from("<HHLLH", chunk.read(14))
        except struct.error:
            raise EOFError from None
        if wFormatTag == WAVE_FORMAT_PCM:
            try:
                sampwidth = struct.unpack_from("<H", chunk.read(2))[0]
            except struct.error:
                raise EOFError from None
            # The header stores bits per sample; round up to whole bytes.
            self._sampwidth = (sampwidth + 7) // 8
            if not self._sampwidth:
                raise Error("bad sample width")
        else:
            raise Error("unknown format: %r" % (wFormatTag,))
        if not self._nchannels:
            raise Error("bad # of channels")
        self._framesize = self._nchannels * self._sampwidth
        self._comptype = "NONE"
        self._compname = "not compressed"
322          self._compname = "not compressed"
323  
324  
class Wave_write:
    """Used for wave files opened in write mode.

    Do not construct directly, but call `open` instead."""

    def __init__(self, f):
        """*f* is a filename (opened here in binary mode and closed again by
        `close`) or an already open binary file object."""
        # Set before anything that can raise: __del__ -> close() still runs on
        # an instance whose construction failed (e.g. builtins.open raising),
        # and close() reads self._file.
        self._file = None
        self._i_opened_the_file = None
        if isinstance(f, str):
            f = builtins.open(f, "wb")
            self._i_opened_the_file = f
        try:
            self.initfp(f)
        except BaseException:
            if self._i_opened_the_file:
                f.close()
            raise

    def initfp(self, file):
        """Reset all writer state for output to *file*"""
        self._file = file
        self._convert = None
        self._nchannels = 0
        self._sampwidth = 0
        self._framerate = 0
        self._nframes = 0
        self._nframeswritten = 0
        self._datawritten = 0
        self._datalength = 0
        self._headerwritten = False

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    #
    # User visible methods.
    #
    def setnchannels(self, nchannels):
        """Set the number of channels (1 for mono, 2 for stereo)"""
        if self._datawritten:
            raise Error("cannot change parameters after starting to write")
        if nchannels < 1:
            raise Error("bad # of channels")
        self._nchannels = nchannels

    def getnchannels(self):
        """Get the number of channels (1 for mono, 2 for stereo)"""
        if not self._nchannels:
            raise Error("number of channels not set")
        return self._nchannels

    def setsampwidth(self, sampwidth):
        """Set the sample width in bytes (1 to 4)"""
        if self._datawritten:
            raise Error("cannot change parameters after starting to write")
        if sampwidth < 1 or sampwidth > 4:
            raise Error("bad sample width")
        self._sampwidth = sampwidth

    def getsampwidth(self):
        """Get the sample width in bytes"""
        if not self._sampwidth:
            raise Error("sample width not set")
        return self._sampwidth

    def setframerate(self, framerate):
        """Set the sample rate in Hz (rounded to the nearest integer)"""
        if self._datawritten:
            raise Error("cannot change parameters after starting to write")
        if framerate <= 0:
            raise Error("bad frame rate")
        self._framerate = int(round(framerate))

    def getframerate(self):
        """Get the sample rate in Hz"""
        if not self._framerate:
            raise Error("frame rate not set")
        return self._framerate

    def setnframes(self, nframes):
        """Set the expected number of frames.

        The value only sizes the initial header; if the amount of data
        actually written differs, the header is patched afterwards (this
        requires a seekable file)."""
        if self._datawritten:
            raise Error("cannot change parameters after starting to write")
        self._nframes = nframes

    def getnframes(self):
        """Get the number of frames written so far"""
        return self._nframeswritten

    def getcomptype(self):
        """Get the compression type (always ``"NONE"``: only PCM is written)"""
        return "NONE"

    def getcompname(self):
        """Get the human-readable compression name"""
        return "not compressed"

    def setparams(self, params):
        """Set all parameters at once.

        *params* is a 6-tuple ``(nchannels, sampwidth, framerate, nframes,
        comptype, compname)``; *comptype* and *compname* are accepted for
        compatibility but ignored, since only PCM data is written."""
        nchannels, sampwidth, framerate, nframes, comptype, compname = params
        if self._datawritten:
            raise Error("cannot change parameters after starting to write")
        self.setnchannels(nchannels)
        self.setsampwidth(sampwidth)
        self.setframerate(framerate)
        self.setnframes(nframes)

    def getparams(self):
        """Get all parameters at once, in the layout `setparams` accepts.

        :raises Error: if channels, sample width or frame rate are unset.
        """
        if not self._nchannels or not self._sampwidth or not self._framerate:
            raise Error("not all parameters set")
        return _wave_params(
            self._nchannels,
            self._sampwidth,
            self._framerate,
            self._nframes,
            "NONE",
            "not compressed",
        )

    def tell(self):
        """Get the current position in the audio data"""
        return self._nframeswritten

    def writeframesraw(self, data):
        """Write data to the file without updating the header"""
        if not isinstance(data, (bytes, bytearray)):
            # Accept any buffer (e.g. array.array) and count it in bytes.
            data = memoryview(data).cast("B")
        self._ensure_header_written(len(data))
        nframes = len(data) // (self._sampwidth * self._nchannels)
        if self._convert:
            data = self._convert(data)
        self._file.write(data)
        self._datawritten += len(data)
        self._nframeswritten = self._nframeswritten + nframes

    def writeframes(self, data):
        """Write data to the file and update the header if needed"""
        self.writeframesraw(data)
        if self._datalength != self._datawritten:
            self._patchheader()

    def close(self):
        """Close the file.

        Writes the header if nothing has been written yet, patches the length
        fields if they are stale, and flushes. The underlying file is closed
        only if it was opened from a filename by this object."""
        try:
            if self._file:
                self._ensure_header_written(0)
                if self._datalength != self._datawritten:
                    self._patchheader()
                self._file.flush()
        finally:
            self._file = None
            file = self._i_opened_the_file
            if file:
                self._i_opened_the_file = None
                file.close()

    #
    # Internal methods.
    #

    def _ensure_header_written(self, datasize):
        """Write the header once, sized for *datasize* bytes if no frame
        count was set; raise Error if a required parameter is missing."""
        if not self._headerwritten:
            if not self._nchannels:
                raise Error("# channels not specified")
            if not self._sampwidth:
                raise Error("sample width not specified")
            if not self._framerate:
                raise Error("sampling rate not specified")
            self._write_header(datasize)

    def _write_header(self, initlength):
        """Write the RIFF, ``fmt `` and ``data`` chunk headers.

        The positions of the two length fields are remembered (when the file
        supports ``tell``) so `_patchheader` can rewrite them later."""
        assert not self._headerwritten
        self._file.write(b"RIFF")
        if not self._nframes:
            self._nframes = initlength // (self._nchannels * self._sampwidth)
        self._datalength = self._nframes * self._nchannels * self._sampwidth
        try:
            self._form_length_pos = self._file.tell()
        except (AttributeError, OSError):
            self._form_length_pos = None
        # RIFF length = "WAVE" (4) + fmt chunk (8 + 16) + data header (8)
        # + sample data, i.e. 36 + data length.
        self._file.write(
            struct.pack(
                "<L4s4sLHHLLHH4s",
                36 + self._datalength,
                b"WAVE",
                b"fmt ",
                16,
                WAVE_FORMAT_PCM,
                self._nchannels,
                self._framerate,
                self._nchannels * self._framerate * self._sampwidth,  # bytes/s
                self._nchannels * self._sampwidth,  # block align
                self._sampwidth * 8,  # bits per sample
                b"data",
            )
        )
        if self._form_length_pos is not None:
            self._data_length_pos = self._file.tell()
        self._file.write(struct.pack("<L", self._datalength))
        self._headerwritten = True

    def _patchheader(self):
        """Rewrite both length fields to match the data actually written,
        then restore the file position. Requires a seekable file."""
        assert self._headerwritten
        if self._datawritten == self._datalength:
            return
        curpos = self._file.tell()
        self._file.seek(self._form_length_pos, 0)
        self._file.write(struct.pack("<L", 36 + self._datawritten))
        self._file.seek(self._data_length_pos, 0)
        self._file.write(struct.pack("<L", self._datawritten))
        self._file.seek(curpos, 0)
        self._datalength = self._datawritten
518          self._datalength = self._datawritten
519  
520  
def open(f, mode=None):  # pylint: disable=redefined-builtin
    """Open a wave file in reading (default) or writing (``mode="w"``) mode.

    The argument may be a filename or an open file. When *mode* is omitted,
    the file object's own ``mode`` attribute is used if it has one,
    otherwise reading is assumed.

    In reading mode, returns a `Wave_read` object.
    In writing mode, returns a `Wave_write` object.
    """
    if mode is None:
        mode = getattr(f, "mode", "rb")
    if mode in ("w", "wb"):
        return Wave_write(f)
    if mode in ("r", "rb"):
        return Wave_read(f)
    raise Error("mode must be 'r', 'rb', 'w', or 'wb'")
538      else:
539          raise Error("mode must be 'r', 'rb', 'w', or 'wb'")