test_byte_stream.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import warnings

import pytest

from haystack.dataclasses import ByteStream


def test_from_file_path(tmp_path, request):
    """`from_file_path` loads the file bytes and forwards `mime_type` / `meta` as given."""
    test_bytes = b"Hello, world!\n"
    test_path = tmp_path / request.node.name
    test_path.write_bytes(test_bytes)

    b = ByteStream.from_file_path(test_path)
    assert b.data == test_bytes
    # No explicit mime type and no guessing requested -> no mime type set.
    assert b.mime_type is None

    b = ByteStream.from_file_path(test_path, mime_type="text/plain")
    assert b.data == test_bytes
    assert b.mime_type == "text/plain"

    b = ByteStream.from_file_path(test_path, meta={"foo": "bar"})
    assert b.data == test_bytes
    assert b.meta == {"foo": "bar"}


@pytest.mark.parametrize(
    "file_path, expected_mime_types",
    [
        ("spam.jpeg", {"image/jpeg"}),
        ("spam.jpg", {"image/jpeg"}),
        ("spam.png", {"image/png"}),
        ("spam.gif", {"image/gif"}),
        ("spam.svg", {"image/svg+xml"}),
        ("spam.js", {"text/javascript", "application/javascript"}),
        ("spam.txt", {"text/plain"}),
        ("spam.html", {"text/html"}),
        ("spam.htm", {"text/html"}),
        ("spam.css", {"text/css"}),
        ("spam.csv", {"text/csv"}),
        ("spam.md", {"text/markdown"}),  # custom mapping
        ("spam.markdown", {"text/markdown"}),  # custom mapping
        ("spam.msg", {"application/vnd.ms-outlook"}),  # custom mapping
        ("spam.pdf", {"application/pdf"}),
        ("spam.xml", {"application/xml", "text/xml"}),
        ("spam.json", {"application/json"}),
        ("spam.doc", {"application/msword"}),
        ("spam.docx", {"application/vnd.openxmlformats-officedocument.wordprocessingml.document"}),
        ("spam.xls", {"application/vnd.ms-excel"}),
        ("spam.xlsx", {"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"}),
        ("spam.ppt", {"application/vnd.ms-powerpoint"}),
        ("spam.pptx", {"application/vnd.openxmlformats-officedocument.presentationml.presentation"}),
    ],
)
def test_from_file_path_guess_mime_type(file_path, expected_mime_types, tmp_path):
    """With `guess_mime_type=True` the mime type is derived from the file extension.

    Some extensions map to more than one accepted value because the answer
    can differ between platforms' mime tables, so any member of the set passes.
    """
    test_file = tmp_path / file_path
    # Only the extension matters for guessing, so an empty file is enough.
    test_file.touch()

    b = ByteStream.from_file_path(test_file, guess_mime_type=True)
    assert b.mime_type in expected_mime_types


def test_explicit_mime_type_is_not_overwritten_by_guessing(tmp_path):
    """An explicit `mime_type` wins over the extension-based guess."""
    # create empty file with correct extension
    test_file = tmp_path / "sample.md"
    test_file.touch()

    explicit_mime_type = "text/x-rst"
    b = ByteStream.from_file_path(test_file, mime_type=explicit_mime_type, guess_mime_type=True)
    assert b.mime_type == explicit_mime_type


def test_from_string():
    """`from_string` encodes the text and forwards `mime_type` / `meta` as given."""
    test_string = "Hello, world!"
    b = ByteStream.from_string(test_string)
    assert b.data.decode() == test_string
    assert b.mime_type is None

    b = ByteStream.from_string(test_string, mime_type="text/plain")
    assert b.data.decode() == test_string
    assert b.mime_type == "text/plain"

    b = ByteStream.from_string(test_string, meta={"foo": "bar"})
    assert b.data.decode() == test_string
    assert b.meta == {"foo": "bar"}


def test_to_string():
    """`to_string` round-trips text produced by `from_string`."""
    test_string = "Hello, world!"
    b = ByteStream.from_string(test_string)
    assert b.to_string() == test_string


def test_to_from_string_encoding():
    """The `encoding` argument is honoured by both `from_string` and `to_string`."""
    test_string = "Hello Baščaršija!"
    # 'č' (U+010D) has no ISO-8859-1 representation, so encoding must fail.
    with pytest.raises(UnicodeEncodeError):
        ByteStream.from_string(test_string, encoding="ISO-8859-1")

    bs = ByteStream.from_string(test_string)  # default encoding is utf-8

    # Decoding UTF-8 bytes as Latin-1 succeeds but yields mojibake for the non-ASCII chars.
    assert bs.to_string(encoding="ISO-8859-1") != test_string
    assert bs.to_string(encoding="utf-8") == test_string


def test_to_string_encoding_error():
    # `to_string` propagates UnicodeDecodeError when the bytes are not valid in
    # the requested encoding: the 13-byte UTF-8 payload has odd length, which
    # cannot be decoded as UTF-16 (2-byte code units).
    b = ByteStream.from_string("Hello, world!")
    with pytest.raises(UnicodeDecodeError):
        b.to_string("utf-16")


def test_to_file(tmp_path, request):
    """`to_file` writes the raw bytes to the given path."""
    test_str = "Hello, world!\n"
    test_path = tmp_path / request.node.name

    ByteStream(test_str.encode()).to_file(test_path)
    assert test_path.read_bytes().decode() == test_str


def test_str_truncation():
    """`str()` of a large stream stays short but still shows mime type and meta."""
    test_str = "1234567890" * 100
    b = ByteStream.from_string(test_str, mime_type="text/plain", meta={"foo": "bar"})
    string_repr = str(b)
    assert len(string_repr) < 200
    assert "text/plain" in string_repr
    assert "foo" in string_repr


def test_to_dict():
    """`to_dict` serializes data as a list of byte values alongside mime type and meta."""
    test_str = "Hello, world!"
    b = ByteStream.from_string(test_str, mime_type="text/plain", meta={"foo": "bar"})
    d = b.to_dict()
    assert d["data"] == list(test_str.encode())
    assert d["mime_type"] == "text/plain"
    assert d["meta"] == {"foo": "bar"}


def test_to_trace_dict():
    """`_to_trace_dict` replaces the payload with a size summary."""
    b = ByteStream(data=b"Hello, world!", mime_type="text/plain", meta={"foo": "bar"})
    d = b._to_trace_dict()
    assert d["data"] == "Binary data (13 bytes)"
    assert d["mime_type"] == "text/plain"
    assert d["meta"] == {"foo": "bar"}


def test_from_dict():
    """`from_dict(to_dict(b))` reconstructs an equivalent stream."""
    test_str = "Hello, world!"
    b = ByteStream.from_string(test_str, mime_type="text/plain", meta={"foo": "bar"})
    d = b.to_dict()
    b2 = ByteStream.from_dict(d)
    assert b2.data == b.data
    assert b2.mime_type == b.mime_type
    assert b2.meta == b.meta
    assert str(b2) == str(b)


def test_no_warning_on_init():
    """Constructing a ByteStream must not emit any warning."""
    with warnings.catch_warnings():
        # Escalate every warning to an error so any emission fails the test.
        warnings.simplefilter("error", Warning)
        ByteStream(data=b"hello", mime_type="text/plain", meta={"k": "v"})


def test_warn_on_inplace_mutation():
    """Assigning to a field after construction warns and points to `dataclasses.replace`."""
    b = ByteStream(data=b"hello")
    with pytest.warns(Warning, match="dataclasses.replace"):
        b.data = b"world"