/ tests / conftest.py
conftest.py
  1  """Configuration for the pytest test suite."""
  2  
  3  from __future__ import annotations
  4  
  5  import json
  6  import os
  7  import random
  8  import subprocess
  9  import sys
 10  import time
 11  from contextlib import suppress
 12  from pathlib import Path
 13  from typing import TYPE_CHECKING
 14  
 15  import psutil
 16  import pytest
 17  import requests
 18  
 19  from aria2p import API, Client, enable_logger
 20  from tests import CONFIGS_DIR, SESSIONS_DIR
 21  
 22  if TYPE_CHECKING:
 23      from collections.abc import Iterator
 24  
 25  
 26  @pytest.fixture(autouse=True)
 27  def tests_logs(request: pytest.FixtureRequest) -> None:
 28      # put logs in tests/logs
 29      log_path = Path("tests") / "logs"
 30  
 31      # tidy logs in subdirectories based on test module and class names
 32      module = request.module
 33      class_ = request.cls
 34      name = request.node.name + ".log"
 35  
 36      if module:
 37          log_path /= module.__name__.replace("tests.", "")
 38      if class_:
 39          log_path /= class_.__name__
 40  
 41      log_path.mkdir(parents=True, exist_ok=True)
 42  
 43      # append last part of the name and enable logger
 44      log_path /= name
 45      if log_path.exists():
 46          log_path.unlink()
 47      enable_logger(sink=str(log_path), level=os.environ.get("PYTEST_LOG_LEVEL", "TRACE"))
 48  
 49  
 50  def spawn_and_wait_server(port: int = 8779) -> subprocess.Popen:
 51      process = subprocess.Popen(  # noqa: S603
 52          [
 53              sys.executable,
 54              "-m",
 55              "uvicorn",
 56              "tests.http_server:app",
 57              "--port",
 58              str(port),
 59          ],
 60          stdout=subprocess.DEVNULL,
 61          stderr=subprocess.DEVNULL,
 62      )
 63      while True:
 64          try:
 65              requests.get(f"http://localhost:{port}/1024")  # noqa: S113
 66          except:  # noqa: E722
 67              time.sleep(0.1)
 68          else:
 69              break
 70      return process
 71  
 72  
 73  @pytest.fixture(scope="session", autouse=True)
 74  def http_server(tmp_path_factory: pytest.TempPathFactory, worker_id: str) -> Iterator:
 75      if worker_id == "master":
 76          # single worker: just run the HTTP server
 77          process = spawn_and_wait_server()
 78          yield process
 79          process.kill()
 80          process.wait()
 81          return
 82  
 83      # get the temp directory shared by all workers
 84      root_tmp_dir = tmp_path_factory.getbasetemp().parent
 85  
 86      # try to get a lock
 87      lock = root_tmp_dir / "lock"
 88      try:
 89          lock.mkdir(exist_ok=False)
 90      except FileExistsError:
 91          yield  # failed, don't run the HTTP server
 92          return
 93  
 94      # got the lock, run the HTTP server
 95      process = spawn_and_wait_server()
 96      yield process
 97      process.kill()
 98      process.wait()
 99  
100  
class Aria2Server:
    """An `aria2c` process with RPC enabled, plus a client and an API instance bound to it.

    Used as a context manager, the process is started on enter. On exit it is killed
    and the download directory is deleted.
    """

    def __init__(
        self,
        tmp_dir: Path,
        port: int,
        config: str | Path | None = None,
        session: str | Path | list[str] | None = None,
        secret: str = "",
    ) -> None:
        """Initialize the server (the process is not started yet).

        Parameters:
            tmp_dir: Temporary download directory.
            port: Server port.
            config: aria2c configuration file. Defaults to `default.conf` in the test configs directory.
            session: aria2c session file name, relative to the test sessions directory,
                or a list of session lines to write into `tmp_dir`.
            secret: Server secret.

        Raises:
            ValueError: When `session` names a session file that does not exist.
        """
        self.tmp_dir = tmp_dir
        self.port = port

        # Fall back to the default test configuration when none is given.
        config = config or CONFIGS_DIR / "default.conf"

        # Build the command used to launch the aria2c process.
        command = [
            "aria2c",
            f"--dir={self.tmp_dir}",
            "--file-allocation=none",
            "--quiet",
            "--enable-rpc=true",
            f"--rpc-listen-port={self.port}",
            f"--conf-path={config}",
        ]
        if session:
            if isinstance(session, list):
                # Inline session: write the lines to a file aria2c can read.
                session_path = self.tmp_dir / "_session.txt"
                session_path.write_text("\n".join(session))
            else:
                session_path = SESSIONS_DIR / session
                if not session_path.exists():
                    raise ValueError(f"no such session: {session}")
            command.append(f"--input-file={session_path}")
        if secret:
            command.append(f"--rpc-secret={secret}")

        self.command = command
        self.process: subprocess.Popen | None = None

        # Client and API instance talking to this server's port.
        self.client = Client(port=self.port, secret=secret, timeout=20)
        self.api = API(self.client)

    def __enter__(self) -> Aria2Server:  # noqa: PYI034
        """Start the server and return it."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):  # noqa: ANN001
        """Kill the server and delete its download directory."""
        self.destroy(force=True)

    def reach(self, retries: int = 5) -> bool:
        """Check whether the server answers RPC calls.

        Parameters:
            retries: Number of attempts. Each failed attempt is followed by a 0.1-second pause.

        Returns:
            Whether the server answered within the given number of attempts.
        """
        for _ in range(retries):
            try:
                self.client.list_methods()
            except requests.ConnectionError:
                time.sleep(0.1)
            else:
                return True
        return False

    def start(self) -> None:
        """Start the aria2c process, retrying until it answers RPC calls.

        Any aria2c process already listening on the same RPC port is killed first.
        """
        for proc in psutil.process_iter():
            try:
                cmdline = proc.cmdline()
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                # The process ended meanwhile, or cannot be inspected by this
                # user: either way it is not one of our aria2c processes.
                continue
            if "aria2c" in cmdline and f"--rpc-listen-port={self.port}" in cmdline:
                with suppress(psutil.NoSuchProcess):
                    proc.kill()
                    proc.wait()
                break

        # A fresh process may not answer in time: kill it and try again.
        while True:
            self.process = subprocess.Popen(self.command)  # noqa: S603
            if self.reach(retries=10):
                break
            self.kill()

    def wait(self) -> None:
        """Block until the process, if any, has exited."""
        if self.process:
            self.process.wait()

    def terminate(self) -> None:
        """Terminate the process, if any, and wait for it to exit."""
        if self.process:
            self.process.terminate()
        self.wait()

    def kill(self) -> None:
        """Kill the process, if any, and wait for it to exit."""
        if self.process:
            self.process.kill()
        self.wait()

    def rmdir(self, directory: Path | None = None) -> None:
        """Recursively delete a directory (the download directory by default).

        Symbolic links are removed, never followed, so nothing outside the directory is deleted.

        Parameters:
            directory: The directory to delete. Defaults to `tmp_dir`.
        """
        if directory is None:
            directory = self.tmp_dir
        for item in directory.iterdir():
            if item.is_dir() and not item.is_symlink():
                self.rmdir(item)
            else:
                item.unlink()
        directory.rmdir()

    def destroy(self, *, force: bool = False) -> None:
        """Stop the process and delete the download directory.

        Parameters:
            force: Kill the process instead of terminating it.
        """
        if force:
            self.kill()
        else:
            self.terminate()
        self.rmdir()
235  
236  
# State intended to be shared by every test process, relative to the current
# working directory: a JSON file listing the RPC ports currently reserved,
# and a directory whose existence marks the lock guarding that file.
ports_file, lock_dir = Path(".ports.json"), Path(".lockdir")
239  
240  
def get_lock() -> None:
    """Block until the cross-process lock is acquired.

    The lock is held while `lock_dir` exists: creating it fails when another
    process holds it, in which case creation is retried every 25 milliseconds.
    """
    while True:
        with suppress(FileExistsError):
            lock_dir.mkdir(exist_ok=False)
            return
        time.sleep(0.025)
249  
250  
def release_lock() -> None:
    """Release the cross-process lock by removing `lock_dir`, if it exists."""
    try:
        lock_dir.rmdir()
    except FileNotFoundError:
        # Not held: nothing to release.
        pass
254  
255  
def get_random_port() -> int:
    """Pick a random port number between 15000 and 16000, both included."""
    lowest, highest = 15000, 16000
    return random.randrange(lowest, highest + 1)  # noqa: S311
258  
259  
def get_current_ports() -> list[int]:
    """Return the reserved ports, or an empty list when the ports file does not exist."""
    try:
        content = ports_file.read_text()
    except FileNotFoundError:
        return []
    return json.loads(content)
265  
266  
def set_current_ports(ports: list[int]) -> None:
    """Overwrite the ports file with the given reserved ports, serialized as JSON."""
    with ports_file.open("w") as stream:
        json.dump(ports, stream)
269  
270  
def reserve_port() -> int:
    """Reserve a random port that no other test process has reserved.

    Returns:
        The reserved port number, to be released with `release_port`.
    """
    get_lock()
    try:
        ports = get_current_ports()
        port_number = get_random_port()
        while port_number in ports:
            port_number = get_random_port()
        ports.append(port_number)
        set_current_ports(ports)
    finally:
        # Always release the lock: otherwise every other process would spin forever in `get_lock`.
        release_lock()
    return port_number
283  
284  
def release_port(port_number: int) -> None:
    """Remove a port previously returned by `reserve_port` from the reserved ports.

    Parameters:
        port_number: The port to release.

    Raises:
        ValueError: When the port is not currently reserved.
    """
    get_lock()
    try:
        ports = get_current_ports()
        ports.remove(port_number)
        set_current_ports(ports)
    finally:
        # `remove` raises when the port is missing: the lock must still be
        # released, otherwise every other process would spin forever in `get_lock`.
        release_lock()
291  
292  
@pytest.fixture
def port() -> Iterator[int]:
    """Provide a port reserved for the duration of the test, released afterwards."""
    reserved = reserve_port()
    yield reserved
    release_port(reserved)
298  
299  
@pytest.fixture
def server(tmp_path: Path, port: int) -> Iterator[Aria2Server]:
    """Provide a running aria2c server that downloads into the test's temporary directory.

    When the test ends, the server is killed and the directory deleted.
    """
    aria2 = Aria2Server(tmp_path, port)
    with aria2:
        yield aria2
304  
305  
@pytest.fixture(scope="session", autouse=True)
def _setup(tmp_path_factory: pytest.TempPathFactory, worker_id: str) -> None:
    """Reset the shared ports file and lock once per test run.

    With pytest-xdist, every worker instantiates session fixtures. Only the first
    worker to get here resets the shared state, so that workers starting later
    do not wipe ports reserved (or the lock held) by workers already running tests.

    Parameters:
        tmp_path_factory: Pytest temporary directory factory.
        worker_id: The pytest-xdist worker id (`"master"` without xdist).
    """
    if worker_id != "master":
        # The parent of the base temporary directory is shared by all workers of a run.
        marker = tmp_path_factory.getbasetemp().parent / "setup-done"
        try:
            marker.mkdir(exist_ok=False)
        except FileExistsError:
            return
        # NOTE(review): another worker could still reserve a port between the
        # marker creation above and the reset below; the window is very small.
    ports_file.unlink(missing_ok=True)
    release_lock()