testarchive.py
"""
Archive module tests
"""

import os
import tarfile
import tempfile
import unittest

from zipfile import ZipFile, ZIP_DEFLATED

from txtai.archive import ArchiveFactory, Compress

# pylint: disable=C0411
from utils import Utils


class TestArchive(unittest.TestCase):
    """
    Archive tests.
    """

    def setUp(self):
        """
        Create an isolated scratch directory for each test.
        """

        # A per-test directory avoids fixed-name collisions in the shared system
        # temp directory (e.g. concurrent test runs) and guarantees generated
        # archives and symlinks are removed once the test finishes
        scratch = tempfile.TemporaryDirectory()
        self.addCleanup(scratch.cleanup)
        self.tmpdir = scratch.name

    def testDirectory(self):
        """
        Test directory included in compressed files
        """

        for extension in ["tar", "zip"]:
            # Create archive instance
            archive = ArchiveFactory.create()

            # Create subdirectory in archive working path
            subdir = os.path.join(archive.path(), "dir")
            os.makedirs(subdir, exist_ok=True)

            # Create file in archive working path
            with open(os.path.join(subdir, "test"), "w", encoding="utf-8") as f:
                f.write("test")

            # Save archive - compression format is taken from the file extension
            output = os.path.join(self.tmpdir, f"subdir.{extension}")
            archive.save(output)

            # Extract files from archive into a fresh archive instance
            archive = ArchiveFactory.create()
            archive.load(output)

            # Check if file properly extracted
            extracted = os.path.join(archive.path(), "dir", "test")
            self.assertTrue(os.path.exists(extracted))

    def testInvalidTarLink(self):
        """
        Test invalid tar file with symlinks
        """

        # Dangling symlink whose target lies outside of the archive working path.
        # The scratch directory is new for each test, so no stale link can exist.
        symlink = os.path.join(self.tmpdir, "link")
        os.symlink(os.path.join(self.tmpdir, "noexist"), symlink)

        # Add symlink to TAR file
        path = os.path.join(self.tmpdir, "badtarlink")
        with tarfile.open(path, "w") as tar:
            tar.add(symlink, arcname="l")

        archive = ArchiveFactory.create()

        # Validate error is thrown for file
        with self.assertRaises(IOError):
            archive.load(path, "tar")

    def testInvalidTarPath(self):
        """
        Test invalid tar file with a path outside of base directory
        """

        # Store the test data directory under the parent-directory name ".."
        path = os.path.join(self.tmpdir, "badtarpath")
        with tarfile.open(path, "w") as tar:
            tar.add(Utils.PATH, arcname="..")

        archive = ArchiveFactory.create()

        # Validate error is thrown for file
        with self.assertRaises(IOError):
            archive.load(path, "tar")

    def testInvalidZipPath(self):
        """
        Test invalid zip file with a path outside of base directory
        """

        # Store a real file under a name that escapes the extraction directory
        path = os.path.join(self.tmpdir, "badzippath")
        with ZipFile(path, "w", ZIP_DEFLATED) as zfile:
            zfile.write(os.path.join(Utils.PATH, "article.pdf"), arcname="../article.pdf")

        archive = ArchiveFactory.create()

        # Validate error is thrown for file
        with self.assertRaises(IOError):
            archive.load(path, "zip")

    def testNotImplemented(self):
        """
        Test exceptions for non-implemented methods
        """

        compress = Compress()

        self.assertRaises(NotImplementedError, compress.pack, None, None)
        self.assertRaises(NotImplementedError, compress.unpack, None, None)