testserialize.py
"""
Serialize module tests
"""

import os
import unittest

from unittest.mock import patch

from txtai.serialize import Serialize, SerializeFactory


class TestSerialize(unittest.TestCase):
    """
    Serialize tests.
    """

    def testNotImplemented(self):
        """
        Test exceptions for non-implemented methods
        """

        serialize = Serialize()

        self.assertRaises(NotImplementedError, serialize.loadstream, None)
        self.assertRaises(NotImplementedError, serialize.savestream, None, None)
        self.assertRaises(NotImplementedError, serialize.loadbytes, None)
        self.assertRaises(NotImplementedError, serialize.savebytes, None)

    def testMessagePack(self):
        """
        Test MessagePack encoder
        """

        serializer = SerializeFactory.create()
        self.assertEqual(serializer.loadbytes(serializer.savebytes("test")), "test")

    @patch.dict(os.environ)
    def testPickleDisabled(self):
        """
        Test disabled pickle serialization
        """

        # Run with ALLOW_PICKLE unset so a value inherited from the outer environment
        # can't re-enable pickle loading. patch.dict restores os.environ when the test exits.
        os.environ.pop("ALLOW_PICKLE", None)

        # Generate pickled data with a serializer explicitly allowed to use pickle. Setup is kept
        # outside assertRaises so a ValueError raised here can't make the test pass spuriously.
        data = SerializeFactory.create("pickle", allowpickle=True).savebytes("Test")
        serializer = SerializeFactory.create("pickle")

        # Validate an error is raised when loading without pickle being allowed
        with self.assertRaises(ValueError):
            serializer.loadbytes(data)

    @patch.dict(os.environ, {"ALLOW_PICKLE": "True"})
    def testPickleEnabled(self):
        """
        Test enabled pickle serialization
        """

        serializer = SerializeFactory.create("pickle")
        data = serializer.savebytes("Test")

        # Validate a warning is raised on load, and that the data round-trips intact
        with self.assertWarns(RuntimeWarning):
            self.assertEqual(serializer.loadbytes(data), "Test")