# osv_check.py
"""OSV malware check for MCP extension packages.

Before launching an MCP server via npx/uvx/pipx, queries the OSV (Open Source
Vulnerabilities) API to check if the package has any known malware advisories
(MAL-* IDs). Regular CVEs are ignored — only confirmed malware is blocked.

The API is free, public, and maintained by Google. Typical latency is ~300ms.
Fail-open: network errors allow the package to proceed.

Inspired by Block/goose's extension malware check.
"""

import json
import logging
import os
import re
import urllib.request
from typing import Optional, Tuple

logger = logging.getLogger(__name__)

_OSV_ENDPOINT = os.getenv("OSV_ENDPOINT", "https://api.osv.dev/v1/query")
_TIMEOUT = 10  # seconds

# OSV splits very large result sets into pages linked by ``next_page_token``;
# cap how many pages one check may fetch so worst-case latency stays bounded.
_MAX_PAGES = 5

# How many advisories are quoted in the BLOCKED message.
_MAX_REPORTED = 3

# Launcher name -> OSV ecosystem that its package argument belongs to.
_ECOSYSTEM_BY_COMMAND = {"npx": "npm", "uvx": "PyPI", "pipx": "PyPI"}

# Windows launcher suffixes (``npx.cmd``, ``uvx.exe``, ...) stripped before lookup.
_EXECUTABLE_SUFFIXES = (".cmd", ".exe", ".bat")

# Flags whose value *is* the package being run:
#   npx ``-p``/``--package <spec>``; uvx ``--from <spec>``; pipx run ``--spec <spec>``.
_PACKAGE_FLAGS = {
    "npm": frozenset({"-p", "--package"}),
    "PyPI": frozenset({"--from", "--spec"}),
}

# Flags that consume the following argument without naming the package
# (for uvx, ``-p`` is ``--python``). Not exhaustive: it covers common options
# so that their values are not mistaken for the package name.
_VALUE_FLAGS = {
    "npm": frozenset(),
    "PyPI": frozenset({
        "-p", "--python", "--with", "--with-editable", "--with-requirements",
        "--index", "--index-url", "--extra-index-url", "--default-index",
        "--pip-args",
    }),
}

# Exact npm semver, optionally "v"-prefixed; dist-tags and ranges do not match.
_NPM_EXACT_VERSION = re.compile(r"^v?(\d+\.\d+\.\d+(?:[-+][0-9A-Za-z.+-]+)?)$")
# Scoped npm spec: @scope/name with an optional @<version-or-tag> suffix.
_NPM_SCOPED = re.compile(r"^(@[^/]+/[^@]+)(?:@(.*))?$")

# PyPI requirement: name, optional [extras], then an optional version specifier
# (``==``, ``===``, ``!=``, ``~=``, ``<=``, ``>=``, ``<``, ``>``) or uvx-style ``@version``.
_PYPI_REQUIREMENT = re.compile(
    r"^(?P<name>[A-Za-z0-9._-]+)\s*(?:\[[^\]]*\])?\s*"
    r"(?P<spec>(?:===?|[!~<>]=|[<>]|@).*)?$"
)


def check_package_for_malware(
    command: str, args: list
) -> Optional[str]:
    """Check if an MCP server package has known malware advisories.

    Inspects the *command* (e.g. ``npx``, ``uvx``, ``pipx``) and *args* to
    infer the package name and ecosystem, then queries the OSV API for
    MAL-* advisories.

    Args:
        command: Launcher executable, as a bare name or a path.
        args: Arguments passed to the launcher.

    Returns:
        An error message string if malware is found, or None if clean/unknown.
        Returns None (allow) on network errors, unrecognized commands, or
        arguments from which no package can be inferred.
    """
    ecosystem = _infer_ecosystem(command)
    if not ecosystem:
        return None  # not npx/uvx/pipx — skip

    if _command_basename(command) == "pipx":
        # pipx always takes a subcommand; only `pipx run <pkg>` launches a
        # package. Without this, the word "run" was checked as the package.
        args = _pipx_run_args(args)
        if args is None:
            return None

    package, version = _parse_package_from_args(args, ecosystem)
    if not package:
        return None

    try:
        malware = _query_osv(package, ecosystem, version)
    except Exception as exc:
        # Fail-open: network errors, timeouts, parse failures → allow
        logger.debug("OSV check failed for %s/%s (allowing): %s", ecosystem, package, exc)
        return None

    if not malware:
        return None

    shown = malware[:_MAX_REPORTED]
    ids = ", ".join(m["id"] for m in shown)
    # ``summary`` is optional in OSV records and may be null/empty: fall back to the ID.
    summaries = "; ".join(str(m.get("summary") or m["id"])[:100] for m in shown)
    return (
        f"BLOCKED: Package '{package}' ({ecosystem}) has known malware "
        f"advisories: {ids}. Details: {summaries}"
    )


def _command_basename(command: str) -> str:
    """Return the lower-cased launcher name without directory or Windows suffix.

    Both ``/`` and ``\\`` are treated as separators, so Windows paths work on
    any host. Returns "" for a non-string or empty *command*.
    """
    if not isinstance(command, str) or not command:
        return ""
    base = re.split(r"[\\/]", command)[-1].lower()
    for suffix in _EXECUTABLE_SUFFIXES:
        if base.endswith(suffix):
            return base[: -len(suffix)]
    return base


def _infer_ecosystem(command: str) -> Optional[str]:
    """Infer package ecosystem from the command name ("npm", "PyPI" or None)."""
    return _ECOSYSTEM_BY_COMMAND.get(_command_basename(command))


def _pipx_run_args(args: list) -> Optional[list]:
    """Return the arguments following ``pipx run``, or None for other subcommands."""
    if not args:
        return None
    args = list(args)
    for index, arg in enumerate(args):
        if not isinstance(arg, str) or arg.startswith("-"):
            continue  # pipx-level flags before the subcommand
        return args[index + 1:] if arg == "run" else None
    return None


def _parse_package_from_args(
    args: list, ecosystem: str
) -> Tuple[Optional[str], Optional[str]]:
    """Extract package name and optional version from command args.

    The package is the value of an explicit package flag (``--package x``,
    ``--package=x``, ``--from x``, ``--spec x``) if one appears first, otherwise
    the first positional argument. Values of known value-taking flags (e.g.
    uvx ``--python 3.12``) are skipped. Non-string args are ignored.

    Returns (package_name, version) or (None, None) if not parseable.
    """
    if not args:
        return None, None

    package_flags = _PACKAGE_FLAGS.get(ecosystem, frozenset())
    value_flags = _VALUE_FLAGS.get(ecosystem, frozenset())

    package_token = None
    remaining = iter(args)
    for arg in remaining:
        if not isinstance(arg, str):
            continue
        if not arg.startswith("-"):
            package_token = arg  # first positional argument names the package
            break
        flag, has_value, value = arg.partition("=")
        if flag in package_flags:
            if not has_value:
                value = next(remaining, None)
            package_token = value if isinstance(value, str) else None
            break
        if flag in value_flags and not has_value:
            next(remaining, None)  # skip the flag's separate value argument

    if not package_token:
        return None, None

    if ecosystem == "npm":
        return _parse_npm_package(package_token)
    if ecosystem == "PyPI":
        return _parse_pypi_package(package_token)
    return package_token, None


def _exact_npm_version(spec: Optional[str]) -> Optional[str]:
    """Return *spec* as an exact semver string, or None for tags/ranges/empty."""
    if not spec:
        return None
    match = _NPM_EXACT_VERSION.match(spec.strip())
    return match.group(1) if match else None


def _parse_npm_package(token: str) -> Tuple[Optional[str], Optional[str]]:
    """Parse npm package: @scope/name@version or name@version.

    Only exact versions are returned. Dist-tags (``latest``, ``next``) and
    ranges (``^1.2.0``) yield None for both scoped and unscoped names, so OSV
    is queried across all versions instead of with a non-version string.
    """
    if token.startswith("@"):
        match = _NPM_SCOPED.match(token)
        if match:
            return match.group(1), _exact_npm_version(match.group(2))
        return token, None
    if "@" in token:
        name, _, spec = token.rpartition("@")
        return name, _exact_npm_version(spec)
    return token, None


def _exact_pypi_version(spec: Optional[str]) -> Optional[str]:
    """Return the single pinned version in a PyPI specifier, else None.

    Accepts ``==X`` and uvx's ``@X`` (``@latest`` means unpinned). Ranges,
    exclusions, ``===``, wildcards, multiple clauses, markers and URLs do not
    pin one version and yield None.
    """
    if not spec:
        return None
    if spec.startswith("@"):
        candidate = spec[1:].strip()
        if candidate == "latest":
            return None
    elif spec.startswith("==") and not spec.startswith("==="):
        candidate = spec[2:].strip()
    else:
        return None
    if not candidate or any(ch in candidate for ch in "*,;: /"):
        return None
    return candidate


def _parse_pypi_package(token: str) -> Tuple[Optional[str], Optional[str]]:
    """Parse PyPI package: name, name[extras], name==version or name@version.

    Tokens that are not requirement-shaped (e.g. ``git+https://...``) are
    returned unchanged with no version.
    """
    match = _PYPI_REQUIREMENT.match(token)
    if not match:
        return token, None
    return match.group("name"), _exact_pypi_version(match.group("spec"))


def _is_malware(vuln) -> bool:
    """True if *vuln* is an OSV record whose ID marks it as malware (``MAL-``)."""
    if not isinstance(vuln, dict):
        return False
    vuln_id = vuln.get("id")
    return isinstance(vuln_id, str) and vuln_id.startswith("MAL-")


def _post_osv(payload: dict) -> dict:
    """POST *payload* to the OSV query endpoint and return the decoded JSON object.

    Raises:
        urllib.error.URLError / OSError on network failure, ValueError on a
        non-JSON or non-object response.
    """
    req = urllib.request.Request(
        _OSV_ENDPOINT,
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "User-Agent": "hermes-agent-osv-check/1.0",
        },
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=_TIMEOUT) as resp:
        result = json.loads(resp.read())
    if not isinstance(result, dict):
        raise ValueError(f"unexpected OSV response type: {type(result).__name__}")
    return result


def _query_osv(
    package: str, ecosystem: str, version: Optional[str] = None
) -> list:
    """Query the OSV API for MAL-* advisories. Returns list of malware vulns.

    Follows ``next_page_token`` for up to ``_MAX_PAGES`` pages. An error on
    the first request propagates (the caller fails open). An error on a later
    page is logged, and the advisories gathered so far are returned so that a
    confirmed hit is not discarded. Malformed records are skipped.
    """
    payload = {"package": {"name": package, "ecosystem": ecosystem}}
    if version:
        payload["version"] = version

    malware = []
    for page in range(_MAX_PAGES):
        try:
            result = _post_osv(payload)
        except Exception as exc:
            if page == 0:
                raise
            logger.debug("OSV pagination stopped for %s/%s: %s", ecosystem, package, exc)
            break
        vulns = result.get("vulns") or []
        # Only malware advisories — ignore regular CVEs
        malware.extend(v for v in vulns if _is_malware(v))
        next_token = result.get("next_page_token")
        if not next_token:
            break
        # OSV requires the original query fields alongside the page token.
        payload["page_token"] = next_token
    return malware