# conftest.py
1 """Configuration for the pytest test suite.""" 2 3 from __future__ import annotations 4 5 import json 6 import os 7 import random 8 import subprocess 9 import sys 10 import time 11 from contextlib import suppress 12 from pathlib import Path 13 from typing import TYPE_CHECKING 14 15 import psutil 16 import pytest 17 import requests 18 19 from aria2p import API, Client, enable_logger 20 from tests import CONFIGS_DIR, SESSIONS_DIR 21 22 if TYPE_CHECKING: 23 from collections.abc import Iterator 24 25 26 @pytest.fixture(autouse=True) 27 def tests_logs(request: pytest.FixtureRequest) -> None: 28 # put logs in tests/logs 29 log_path = Path("tests") / "logs" 30 31 # tidy logs in subdirectories based on test module and class names 32 module = request.module 33 class_ = request.cls 34 name = request.node.name + ".log" 35 36 if module: 37 log_path /= module.__name__.replace("tests.", "") 38 if class_: 39 log_path /= class_.__name__ 40 41 log_path.mkdir(parents=True, exist_ok=True) 42 43 # append last part of the name and enable logger 44 log_path /= name 45 if log_path.exists(): 46 log_path.unlink() 47 enable_logger(sink=str(log_path), level=os.environ.get("PYTEST_LOG_LEVEL", "TRACE")) 48 49 50 def spawn_and_wait_server(port: int = 8779) -> subprocess.Popen: 51 process = subprocess.Popen( # noqa: S603 52 [ 53 sys.executable, 54 "-m", 55 "uvicorn", 56 "tests.http_server:app", 57 "--port", 58 str(port), 59 ], 60 stdout=subprocess.DEVNULL, 61 stderr=subprocess.DEVNULL, 62 ) 63 while True: 64 try: 65 requests.get(f"http://localhost:{port}/1024") # noqa: S113 66 except: # noqa: E722 67 time.sleep(0.1) 68 else: 69 break 70 return process 71 72 73 @pytest.fixture(scope="session", autouse=True) 74 def http_server(tmp_path_factory: pytest.TempPathFactory, worker_id: str) -> Iterator: 75 if worker_id == "master": 76 # single worker: just run the HTTP server 77 process = spawn_and_wait_server() 78 yield process 79 process.kill() 80 process.wait() 81 return 82 83 # get the temp 
directory shared by all workers 84 root_tmp_dir = tmp_path_factory.getbasetemp().parent 85 86 # try to get a lock 87 lock = root_tmp_dir / "lock" 88 try: 89 lock.mkdir(exist_ok=False) 90 except FileExistsError: 91 yield # failed, don't run the HTTP server 92 return 93 94 # got the lock, run the HTTP server 95 process = spawn_and_wait_server() 96 yield process 97 process.kill() 98 process.wait() 99 100 101 class Aria2Server: 102 def __init__( 103 self, 104 tmp_dir: Path, 105 port: int, 106 config: str | Path | None = None, 107 session: str | Path | list[str] | None = None, 108 secret: str = "", 109 ) -> None: 110 """Initialize the server. 111 112 Parameters: 113 tmp_dir: Temporary download directory. 114 port: Server port. 115 config: aria2c configuration file. 116 session: aria2c session file. 117 secret: Server secret. 118 """ 119 self.tmp_dir = tmp_dir 120 self.port = port 121 122 # create the command used to launch an aria2c process 123 command = [ 124 "aria2c", 125 f"--dir={self.tmp_dir}", 126 "--file-allocation=none", 127 "--quiet", 128 "--enable-rpc=true", 129 f"--rpc-listen-port={self.port}", 130 ] 131 if config: 132 command.append(f"--conf-path={config}") 133 else: 134 # command.append("--no-conf") 135 config = CONFIGS_DIR / "default.conf" 136 command.append(f"--conf-path={config}") 137 if session: 138 if isinstance(session, list): 139 session_path = self.tmp_dir / "_session.txt" 140 with open(session_path, "w") as stream: 141 stream.write("\n".join(session)) 142 command.append(f"--input-file={session_path}") 143 else: 144 session_path = SESSIONS_DIR / session 145 if not session_path.exists(): 146 raise ValueError(f"no such session: {session}") 147 command.append(f"--input-file={session_path}") 148 if secret: 149 command.append(f"--rpc-secret={secret}") 150 151 self.command = command 152 self.process: subprocess.Popen | None = None 153 154 # create the client with port 155 self.client = Client(port=self.port, secret=secret, timeout=20) 156 157 # create the 
API instance 158 self.api = API(self.client) 159 160 def __enter__(self) -> Aria2Server: # noqa: PYI034 161 self.start() 162 return self 163 164 def __exit__(self, exc_type, exc_val, exc_tb): # noqa: ANN001 165 self.destroy(force=True) 166 167 def reach(self, retries: int = 5) -> bool: 168 while retries: 169 try: 170 self.client.list_methods() 171 except requests.ConnectionError: 172 time.sleep(0.1) 173 retries -= 1 174 else: 175 break 176 else: 177 return False 178 return True 179 180 def start(self) -> None: 181 # Make sure we kill any remaining aria2c process using the same port. 182 for proc in psutil.process_iter(): 183 try: 184 cmdline = proc.cmdline() 185 if "aria2c" in cmdline and f"--rpc-listen-port={self.port}" in cmdline: 186 proc.kill() 187 proc.wait() 188 break 189 except psutil.NoSuchProcess: 190 pass 191 192 # Make sure we start the new process. 193 while True: 194 self.process = subprocess.Popen(self.command) # noqa: S603 195 if self.reach(retries=10): 196 break 197 self.kill() 198 199 def wait(self) -> None: 200 if self.process: 201 while True: 202 try: 203 self.process.wait() 204 except subprocess.TimeoutExpired: 205 pass 206 else: 207 break 208 209 def terminate(self) -> None: 210 if self.process: 211 self.process.terminate() 212 self.wait() 213 214 def kill(self) -> None: 215 if self.process: 216 self.process.kill() 217 self.wait() 218 219 def rmdir(self, directory: Path | None = None) -> None: 220 if directory is None: 221 directory = self.tmp_dir 222 for item in directory.iterdir(): 223 if item.is_dir(): 224 self.rmdir(item) 225 else: 226 item.unlink() 227 directory.rmdir() 228 229 def destroy(self, *, force: bool = False) -> None: 230 if force: 231 self.kill() 232 else: 233 self.terminate() 234 self.rmdir() 235 236 237 ports_file = Path(".ports.json") 238 lock_dir = Path(".lockdir") 239 240 241 def get_lock() -> None: 242 while True: 243 try: 244 lock_dir.mkdir(exist_ok=False) 245 except FileExistsError: 246 time.sleep(0.025) 247 else: 248 
break 249 250 251 def release_lock() -> None: 252 with suppress(FileNotFoundError): 253 lock_dir.rmdir() 254 255 256 def get_random_port() -> int: 257 return random.randint(15000, 16000) # noqa: S311 258 259 260 def get_current_ports() -> list[int]: 261 try: 262 return json.loads(ports_file.read_text()) 263 except FileNotFoundError: 264 return [] 265 266 267 def set_current_ports(ports: list[int]) -> None: 268 ports_file.write_text(json.dumps(ports)) 269 270 271 def reserve_port() -> int: 272 get_lock() 273 274 ports = get_current_ports() 275 port_number = get_random_port() 276 while port_number in ports: 277 port_number = get_random_port() 278 ports.append(port_number) 279 set_current_ports(ports) 280 281 release_lock() 282 return port_number 283 284 285 def release_port(port_number: int) -> None: 286 get_lock() 287 ports = get_current_ports() 288 ports.remove(port_number) 289 set_current_ports(ports) 290 release_lock() 291 292 293 @pytest.fixture 294 def port() -> Iterator[int]: 295 port_number = reserve_port() 296 yield port_number 297 release_port(port_number) 298 299 300 @pytest.fixture 301 def server(tmp_path: Path, port: int) -> Iterator[Aria2Server]: 302 with Aria2Server(tmp_path, port) as server: 303 yield server 304 305 306 @pytest.fixture(scope="session", autouse=True) 307 def _setup() -> None: 308 ports_file.unlink(missing_ok=True) 309 release_lock()