/ test / dataclasses / test_byte_stream.py
test_byte_stream.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import warnings
  6  
  7  import pytest
  8  
  9  from haystack.dataclasses import ByteStream
 10  
 11  
 12  def test_from_file_path(tmp_path, request):
 13      test_bytes = b"Hello, world!\n"
 14      test_path = tmp_path / request.node.name
 15      with open(test_path, "wb") as fd:
 16          assert fd.write(test_bytes)
 17  
 18      b = ByteStream.from_file_path(test_path)
 19      assert b.data == test_bytes
 20      assert b.mime_type == None
 21  
 22      b = ByteStream.from_file_path(test_path, mime_type="text/plain")
 23      assert b.data == test_bytes
 24      assert b.mime_type == "text/plain"
 25  
 26      b = ByteStream.from_file_path(test_path, meta={"foo": "bar"})
 27      assert b.data == test_bytes
 28      assert b.meta == {"foo": "bar"}
 29  
 30  
 31  @pytest.mark.parametrize(
 32      "file_path, expected_mime_types",
 33      [
 34          ("spam.jpeg", {"image/jpeg"}),
 35          ("spam.jpg", {"image/jpeg"}),
 36          ("spam.png", {"image/png"}),
 37          ("spam.gif", {"image/gif"}),
 38          ("spam.svg", {"image/svg+xml"}),
 39          ("spam.js", {"text/javascript", "application/javascript"}),
 40          ("spam.txt", {"text/plain"}),
 41          ("spam.html", {"text/html"}),
 42          ("spam.htm", {"text/html"}),
 43          ("spam.css", {"text/css"}),
 44          ("spam.csv", {"text/csv"}),
 45          ("spam.md", {"text/markdown"}),  # custom mapping
 46          ("spam.markdown", {"text/markdown"}),  # custom mapping
 47          ("spam.msg", {"application/vnd.ms-outlook"}),  # custom mapping
 48          ("spam.pdf", {"application/pdf"}),
 49          ("spam.xml", {"application/xml", "text/xml"}),
 50          ("spam.json", {"application/json"}),
 51          ("spam.doc", {"application/msword"}),
 52          ("spam.docx", {"application/vnd.openxmlformats-officedocument.wordprocessingml.document"}),
 53          ("spam.xls", {"application/vnd.ms-excel"}),
 54          ("spam.xlsx", {"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"}),
 55          ("spam.ppt", {"application/vnd.ms-powerpoint"}),
 56          ("spam.pptx", {"application/vnd.openxmlformats-officedocument.presentationml.presentation"}),
 57      ],
 58  )
 59  def test_from_file_path_guess_mime_type(file_path, expected_mime_types, tmp_path):
 60      test_file = tmp_path / file_path
 61      test_file.touch()
 62  
 63      b = ByteStream.from_file_path(test_file, guess_mime_type=True)
 64      assert b.mime_type in expected_mime_types
 65  
 66  
 67  def test_explicit_mime_type_is_not_overwritten_by_guessing(tmp_path):
 68      # create empty file with correct extension
 69      test_file = tmp_path / "sample.md"
 70      test_file.touch()
 71  
 72      explicit_mime_type = "text/x-rst"
 73      b = ByteStream.from_file_path(test_file, mime_type=explicit_mime_type, guess_mime_type=True)
 74      assert b.mime_type == explicit_mime_type
 75  
 76  
 77  def test_from_string():
 78      test_string = "Hello, world!"
 79      b = ByteStream.from_string(test_string)
 80      assert b.data.decode() == test_string
 81      assert b.mime_type == None
 82  
 83      b = ByteStream.from_string(test_string, mime_type="text/plain")
 84      assert b.data.decode() == test_string
 85      assert b.mime_type == "text/plain"
 86  
 87      b = ByteStream.from_string(test_string, meta={"foo": "bar"})
 88      assert b.data.decode() == test_string
 89      assert b.meta == {"foo": "bar"}
 90  
 91  
 92  def test_to_string():
 93      test_string = "Hello, world!"
 94      b = ByteStream.from_string(test_string)
 95      assert b.to_string() == test_string
 96  
 97  
 98  def test_to_from_string_encoding():
 99      test_string = "Hello Baščaršija!"
100      with pytest.raises(UnicodeEncodeError):
101          ByteStream.from_string(test_string, encoding="ISO-8859-1")
102  
103      bs = ByteStream.from_string(test_string)  # default encoding is utf-8
104  
105      assert bs.to_string(encoding="ISO-8859-1") != test_string
106      assert bs.to_string(encoding="utf-8") == test_string
107  
108  
def test_to_string_encoding_error():
    """``to_string`` must raise ``UnicodeDecodeError`` when asked to decode this stream as utf-16."""
    stream = ByteStream.from_string("Hello, world!")
    # The stream was built with the default encoding, then decoded with a mismatching codec.
    with pytest.raises(UnicodeDecodeError):
        stream.to_string("utf-16")
114  
115  
def test_to_file(tmp_path, request):
    """Write a ``ByteStream`` with ``to_file`` and verify the bytes that end up on disk."""
    content = "Hello, world!\n"
    # The target file is named after the running test node inside pytest's temporary directory.
    target = tmp_path / request.node.name

    ByteStream(content.encode()).to_file(target)
    assert target.read_bytes().decode() == content
123  
124  
def test_str_truncation():
    """``str()`` of a stream holding 1000 characters must stay short and still show mime type and meta."""
    long_text = "1234567890" * 100
    stream = ByteStream.from_string(long_text, mime_type="text/plain", meta={"foo": "bar"})
    rendered = str(stream)

    # The representation must be far shorter than the payload itself.
    assert len(rendered) < 200
    for fragment in ("text/plain", "foo"):
        assert fragment in rendered
132  
133  
def test_to_dict():
    """``to_dict`` must expose the data as a list of byte values plus ``mime_type`` and ``meta``."""
    text = "Hello, world!"
    stream = ByteStream.from_string(text, mime_type="text/plain", meta={"foo": "bar"})
    serialized = stream.to_dict()

    # The payload is expected as a list of integers, one per encoded byte.
    assert serialized["data"] == list(text.encode())
    assert serialized["mime_type"] == "text/plain"
    assert serialized["meta"] == {"foo": "bar"}
141  
142  
def test_to_trace_dict():
    """``_to_trace_dict`` must replace the payload with a size summary and keep ``mime_type`` and ``meta``."""
    stream = ByteStream(data=b"Hello, world!", mime_type="text/plain", meta={"foo": "bar"})
    traced = stream._to_trace_dict()

    # The 13-byte payload is reported as a description string instead of the raw bytes.
    assert traced["data"] == "Binary data (13 bytes)"
    assert traced["mime_type"] == "text/plain"
    assert traced["meta"] == {"foo": "bar"}
149  
150  
def test_from_dict():
    """A stream rebuilt with ``from_dict`` from its own ``to_dict`` output must match the original."""
    original = ByteStream.from_string("Hello, world!", mime_type="text/plain", meta={"foo": "bar"})
    restored = ByteStream.from_dict(original.to_dict())

    assert restored.data == original.data
    assert restored.mime_type == original.mime_type
    assert restored.meta == original.meta
    # The string representations must agree as well.
    assert str(restored) == str(original)
160  
161  
def test_no_warning_on_init():
    """Constructing a ``ByteStream`` must not emit any warning."""
    with warnings.catch_warnings():
        # Escalate every warning to an exception so that any warning fails the test.
        warnings.simplefilter("error", category=Warning)
        ByteStream(data=b"hello", mime_type="text/plain", meta={"k": "v"})
166  
167  
def test_warn_on_inplace_mutation():
    """Assigning to ``data`` on an existing stream must warn with a message matching ``dataclasses.replace``."""
    stream = ByteStream(data=b"hello")
    with pytest.warns(Warning, match="dataclasses.replace"):
        stream.data = b"world"