/ test / python / testarchive.py
testarchive.py
  1  """
  2  Archive module tests
  3  """
  4  
  5  import os
  6  import tarfile
  7  import tempfile
  8  import unittest
  9  
 10  from zipfile import ZipFile, ZIP_DEFLATED
 11  
 12  from txtai.archive import ArchiveFactory, Compress
 13  
 14  # pylint: disable=C0411
 15  from utils import Utils
 16  
 17  
 18  class TestArchive(unittest.TestCase):
 19      """
 20      Archive tests.
 21      """
 22  
 23      def testDirectory(self):
 24          """
 25          Test directory included in compressed files
 26          """
 27  
 28          for extension in ["tar", "zip"]:
 29              # Create archive instance
 30              archive = ArchiveFactory.create()
 31  
 32              # Create subdirectory in archive working path
 33              path = os.path.join(archive.path(), "dir")
 34              os.makedirs(path, exist_ok=True)
 35  
 36              # Create file in archive working path
 37              with open(os.path.join(path, "test"), "w", encoding="utf-8") as f:
 38                  f.write("test")
 39  
 40              # Save archive
 41              path = os.path.join(tempfile.gettempdir(), f"subdir.{extension}")
 42              archive.save(path)
 43  
 44              # Extract files from archive
 45              archive = ArchiveFactory.create()
 46              archive.load(path)
 47  
 48              # Check if file properly extracted
 49              path = os.path.join(archive.path(), "dir", "test")
 50              self.assertTrue(os.path.exists(path))
 51  
 52      def testInvalidTarLink(self):
 53          """
 54          Test invalid tar file with symlinks
 55          """
 56  
 57          symlink = os.path.join(tempfile.gettempdir(), "link")
 58  
 59          # Remove symlink if it already exists
 60          try:
 61              os.remove(symlink)
 62          except OSError:
 63              pass
 64  
 65          # Create symlink and add to TAR file
 66          os.symlink(os.path.join(tempfile.gettempdir(), "noexist"), symlink)
 67  
 68          path = os.path.join(tempfile.gettempdir(), "badtarlink")
 69          with tarfile.open(path, "w") as tar:
 70              tar.add(symlink, arcname="l")
 71  
 72          archive = ArchiveFactory.create()
 73  
 74          # Validate error is thrown for file
 75          with self.assertRaises(IOError):
 76              archive.load(path, "tar")
 77  
 78      def testInvalidTarPath(self):
 79          """
 80          Test invalid tar file with a path outside of base directory
 81          """
 82  
 83          path = os.path.join(tempfile.gettempdir(), "badtarpath")
 84          with tarfile.open(path, "w") as tar:
 85              tar.add(Utils.PATH, arcname="..")
 86  
 87          archive = ArchiveFactory.create()
 88  
 89          # Validate error is thrown for file
 90          with self.assertRaises(IOError):
 91              archive.load(path, "tar")
 92  
 93      def testInvalidZipPath(self):
 94          """
 95          Test invalid zip file with a path outside of base directory
 96          """
 97  
 98          path = os.path.join(tempfile.gettempdir(), "badzippath")
 99          with ZipFile(path, "w", ZIP_DEFLATED) as zfile:
100              zfile.write(Utils.PATH + "/article.pdf", arcname="../article.pdf")
101  
102          archive = ArchiveFactory.create()
103  
104          # Validate error is thrown for file
105          with self.assertRaises(IOError):
106              archive.load(path, "zip")
107  
108      def testNotImplemented(self):
109          """
110          Test exceptions for non-implemented methods
111          """
112  
113          compress = Compress()
114  
115          self.assertRaises(NotImplementedError, compress.pack, None, None)
116          self.assertRaises(NotImplementedError, compress.unpack, None, None)