test_email_sms_tools.py
1 """Tests for the send_email + send_sms builtin tools. 2 3 Covers the happy path with mocked transports, plus precondition errors 4 (missing project context, missing config, missing recipient) so the 5 agent gets a clear ERROR string instead of an unhandled exception when 6 an admin forgets a field. 7 8 The tools do their `import smtplib` / `import requests` / `from 9 restai.database import get_db_wrapper` *inside* the function (so the 10 import cost is paid only when the tool is actually invoked). Tests 11 therefore patch the canonical module paths, not the tool module. 12 """ 13 from __future__ import annotations 14 15 import json 16 import smtplib 17 from types import SimpleNamespace 18 from unittest.mock import MagicMock, patch 19 20 import pytest 21 22 from restai.utils.crypto import encrypt_field 23 24 25 def _fake_project(opts: dict): 26 """Stand-in for ProjectDatabase as the tools see it.""" 27 return SimpleNamespace(options=json.dumps(opts)) 28 29 30 def _fake_db(project_obj): 31 db = MagicMock() 32 db.get_project_by_id.return_value = project_obj 33 return db 34 35 36 # ─── send_email ───────────────────────────────────────────────────────── 37 38 def test_send_email_requires_project_context(): 39 from restai.llms.tools.send_email import send_email 40 out = send_email("subject", "body") 41 assert out.startswith("ERROR:") 42 assert "project context" in out 43 44 45 def test_send_email_missing_smtp_config(): 46 from restai.llms.tools.send_email import send_email 47 db = _fake_db(_fake_project({})) 48 with patch("restai.database.get_db_wrapper", return_value=db): 49 out = send_email("hello", "body", _brain=object(), _project_id=42) 50 assert out.startswith("ERROR:") 51 assert "not configured" in out 52 53 54 def test_send_email_missing_recipient(): 55 from restai.llms.tools.send_email import send_email 56 db = _fake_db(_fake_project({"smtp_host": "smtp.example.com", "smtp_from": "bot@x"})) 57 with patch("restai.database.get_db_wrapper", return_value=db): 58 out = send_email("hi", "body", _brain=object(), _project_id=42) 59 assert out.startswith("ERROR:") 60 assert "recipient" in out 61 62 63 def test_send_email_happy_path(): 64 """STARTTLS path (port 587). Ensure smtplib.SMTP is constructed 65 with host/port/timeout and send_message is called.""" 66 from restai.llms.tools.send_email import send_email 67 opts = { 68 "smtp_host": "smtp.example.com", 69 "smtp_port": 587, 70 "smtp_user": "bot@example.com", 71 "smtp_password": encrypt_field("hunter2"), 72 "smtp_from": "bot@example.com", 73 "email_default_to": "admin@example.com", 74 } 75 db = _fake_db(_fake_project(opts)) 76 77 sent = [] 78 class _FakeSMTP: 79 def __init__(self, host, port, timeout=None): 80 sent.append({"host": host, "port": port, "timeout": timeout}) 81 def __enter__(self): return self 82 def __exit__(self, *a): return False 83 def ehlo(self): pass 84 def starttls(self): sent.append({"starttls": True}) 85 def login(self, u, p): sent.append({"login": (u, p)}) 86 def send_message(self, msg): sent.append({"send": msg["To"]}) 87 88 with patch("restai.database.get_db_wrapper", return_value=db), \ 89 patch("smtplib.SMTP", _FakeSMTP): 90 out = send_email("subj", "body", _brain=object(), _project_id=1) 91 92 assert out.startswith("OK:"), out 93 assert sent[0]["host"] == "smtp.example.com" and sent[0]["port"] == 587 94 assert any("starttls" in s for s in sent) 95 assert any("login" in s and s["login"] == ("bot@example.com", "hunter2") for s in sent) 96 assert any("send" in s and s["send"] == "admin@example.com" for s in sent) 97 98 99 def test_send_email_implicit_tls_path(): 100 """Port 465 should use SMTP_SSL, not STARTTLS.""" 101 from restai.llms.tools.send_email import send_email 102 opts = { 103 "smtp_host": "smtp.example.com", 104 "smtp_port": 465, 105 "smtp_from": "bot@example.com", 106 "email_default_to": "admin@example.com", 107 } 108 db = _fake_db(_fake_project(opts)) 109 110 used_ssl = {"flag": False} 111 class _FakeSSL: 112 def __init__(self, host, port, timeout=None): 113 used_ssl["flag"] = True 114 def __enter__(self): return self 115 def __exit__(self, *a): return False 116 def login(self, u, p): pass 117 def send_message(self, msg): pass 118 119 class _FakePlainSMTP: 120 def __init__(self, *a, **kw): 121 raise AssertionError("plain SMTP must not be used on port 465") 122 123 with patch("restai.database.get_db_wrapper", return_value=db), \ 124 patch("smtplib.SMTP_SSL", _FakeSSL), \ 125 patch("smtplib.SMTP", _FakePlainSMTP): 126 out = send_email("subj", "body", _brain=object(), _project_id=1) 127 128 assert out.startswith("OK:"), out 129 assert used_ssl["flag"] is True 130 131 132 def test_send_email_smtp_failure_returns_error_string(): 133 from restai.llms.tools.send_email import send_email 134 opts = { 135 "smtp_host": "smtp.example.com", 136 "smtp_port": 587, 137 "smtp_from": "bot@example.com", 138 "email_default_to": "admin@example.com", 139 } 140 db = _fake_db(_fake_project(opts)) 141 142 def _raise(*a, **kw): 143 raise smtplib.SMTPException("relay refused") 144 145 with patch("restai.database.get_db_wrapper", return_value=db), \ 146 patch("smtplib.SMTP", _raise): 147 out = send_email("subj", "body", _brain=object(), _project_id=1) 148 assert out.startswith("ERROR:"), out 149 assert "relay refused" in out 150 151 152 # ─── send_sms ─────────────────────────────────────────────────────────── 153 154 def test_send_sms_requires_project_context(): 155 from restai.llms.tools.send_sms import send_sms 156 out = send_sms("hi") 157 assert out.startswith("ERROR:") 158 assert "project context" in out 159 160 161 def test_send_sms_missing_config(): 162 from restai.llms.tools.send_sms import send_sms 163 db = _fake_db(_fake_project({})) 164 with patch("restai.database.get_db_wrapper", return_value=db): 165 out = send_sms("hi", _brain=object(), _project_id=1) 166 assert out.startswith("ERROR:") 167 assert "not configured" in out 168 169 170 def test_send_sms_missing_recipient(): 171 from restai.llms.tools.send_sms import send_sms 172 db = _fake_db(_fake_project({ 173 "twilio_account_sid": "AC123", 174 "twilio_auth_token": encrypt_field("tok"), 175 "twilio_from_number": "+15551234567", 176 })) 177 with patch("restai.database.get_db_wrapper", return_value=db): 178 out = send_sms("hi", _brain=object(), _project_id=1) 179 assert out.startswith("ERROR:") 180 assert "recipient" in out 181 182 183 def test_send_sms_happy_path(): 184 from restai.llms.tools.send_sms import send_sms 185 db = _fake_db(_fake_project({ 186 "twilio_account_sid": "AC123", 187 "twilio_auth_token": encrypt_field("tok_secret"), 188 "twilio_from_number": "+15551234567", 189 "sms_default_to": "+351912345678", 190 })) 191 192 captured = {} 193 class _FakeResp: 194 status_code = 201 195 def json(self): return {"sid": "SMxxxx"} 196 def fake_post(url, auth=None, data=None, timeout=None): 197 captured.update({"url": url, "auth": auth, "data": data, "timeout": timeout}) 198 return _FakeResp() 199 200 with patch("restai.database.get_db_wrapper", return_value=db), \ 201 patch("requests.post", fake_post): 202 out = send_sms("hi from agent", _brain=object(), _project_id=1) 203 204 assert out.startswith("OK:"), out 205 assert "SMxxxx" in out 206 assert "AC123/Messages.json" in captured["url"] 207 assert captured["auth"] == ("AC123", "tok_secret") 208 assert captured["data"] == {"From": "+15551234567", "To": "+351912345678", "Body": "hi from agent"} 209 assert captured["timeout"] is not None 210 211 212 def test_send_sms_chunks_long_messages(): 213 from restai.llms.tools.send_sms import send_sms 214 db = _fake_db(_fake_project({ 215 "twilio_account_sid": "AC123", 216 "twilio_auth_token": encrypt_field("tok"), 217 "twilio_from_number": "+15551234567", 218 "sms_default_to": "+351912345678", 219 })) 220 221 calls = [] 222 class _FakeResp: 223 status_code = 201 224 def json(self): return {"sid": "SMxxxx"} 225 def fake_post(url, auth=None, data=None, timeout=None): 226 calls.append(len(data["Body"])) 227 return _FakeResp() 228 229 with patch("restai.database.get_db_wrapper", return_value=db), \ 230 patch("requests.post", fake_post): 231 out = send_sms("x" * 3500, _brain=object(), _project_id=1) 232 assert out.startswith("OK:") 233 # 1600-char chunks → 3 parts (1600 + 1600 + 300). 234 assert calls == [1600, 1600, 300] 235 236 237 def test_send_sms_twilio_error_surfaces(): 238 from restai.llms.tools.send_sms import send_sms 239 db = _fake_db(_fake_project({ 240 "twilio_account_sid": "AC123", 241 "twilio_auth_token": encrypt_field("tok"), 242 "twilio_from_number": "+15551234567", 243 "sms_default_to": "+351912345678", 244 })) 245 246 class _FakeResp: 247 status_code = 400 248 text = '{"message": "from number not owned"}' 249 def json(self): return {"message": "from number not owned", "code": 21603} 250 with patch("restai.database.get_db_wrapper", return_value=db), \ 251 patch("requests.post", lambda *a, **kw: _FakeResp()): 252 out = send_sms("hi", _brain=object(), _project_id=1) 253 assert out.startswith("ERROR:") 254 assert "from number not owned" in out