# tests/test_email_sms_tools.py
"""Tests for the send_email + send_sms builtin tools.

Covers the happy paths with mocked transports (including the implicit-TLS
port-465 SMTP branch and SMS chunking), transport failures (SMTP exceptions,
Twilio error responses), plus precondition errors (missing project context,
missing config, missing recipient) so the agent gets a clear ERROR string
instead of an unhandled exception when an admin forgets a field.

The tools do their `import smtplib` / `import requests` / `from
restai.database import get_db_wrapper` *inside* the function (so the
import cost is paid only when the tool is actually invoked). Tests
therefore patch the canonical module paths, not the tool module.
"""
 13  from __future__ import annotations
 14  
 15  import json
 16  import smtplib
 17  from types import SimpleNamespace
 18  from unittest.mock import MagicMock, patch
 19  
 20  import pytest
 21  
 22  from restai.utils.crypto import encrypt_field
 23  
 24  
 25  def _fake_project(opts: dict):
 26      """Stand-in for ProjectDatabase as the tools see it."""
 27      return SimpleNamespace(options=json.dumps(opts))
 28  
 29  
 30  def _fake_db(project_obj):
 31      db = MagicMock()
 32      db.get_project_by_id.return_value = project_obj
 33      return db
 34  
 35  
 36  # ─── send_email ─────────────────────────────────────────────────────────
 37  
 38  def test_send_email_requires_project_context():
 39      from restai.llms.tools.send_email import send_email
 40      out = send_email("subject", "body")
 41      assert out.startswith("ERROR:")
 42      assert "project context" in out
 43  
 44  
 45  def test_send_email_missing_smtp_config():
 46      from restai.llms.tools.send_email import send_email
 47      db = _fake_db(_fake_project({}))
 48      with patch("restai.database.get_db_wrapper", return_value=db):
 49          out = send_email("hello", "body", _brain=object(), _project_id=42)
 50      assert out.startswith("ERROR:")
 51      assert "not configured" in out
 52  
 53  
 54  def test_send_email_missing_recipient():
 55      from restai.llms.tools.send_email import send_email
 56      db = _fake_db(_fake_project({"smtp_host": "smtp.example.com", "smtp_from": "bot@x"}))
 57      with patch("restai.database.get_db_wrapper", return_value=db):
 58          out = send_email("hi", "body", _brain=object(), _project_id=42)
 59      assert out.startswith("ERROR:")
 60      assert "recipient" in out
 61  
 62  
 63  def test_send_email_happy_path():
 64      """STARTTLS path (port 587). Ensure smtplib.SMTP is constructed
 65      with host/port/timeout and send_message is called."""
 66      from restai.llms.tools.send_email import send_email
 67      opts = {
 68          "smtp_host": "smtp.example.com",
 69          "smtp_port": 587,
 70          "smtp_user": "bot@example.com",
 71          "smtp_password": encrypt_field("hunter2"),
 72          "smtp_from": "bot@example.com",
 73          "email_default_to": "admin@example.com",
 74      }
 75      db = _fake_db(_fake_project(opts))
 76  
 77      sent = []
 78      class _FakeSMTP:
 79          def __init__(self, host, port, timeout=None):
 80              sent.append({"host": host, "port": port, "timeout": timeout})
 81          def __enter__(self): return self
 82          def __exit__(self, *a): return False
 83          def ehlo(self): pass
 84          def starttls(self): sent.append({"starttls": True})
 85          def login(self, u, p): sent.append({"login": (u, p)})
 86          def send_message(self, msg): sent.append({"send": msg["To"]})
 87  
 88      with patch("restai.database.get_db_wrapper", return_value=db), \
 89           patch("smtplib.SMTP", _FakeSMTP):
 90          out = send_email("subj", "body", _brain=object(), _project_id=1)
 91  
 92      assert out.startswith("OK:"), out
 93      assert sent[0]["host"] == "smtp.example.com" and sent[0]["port"] == 587
 94      assert any("starttls" in s for s in sent)
 95      assert any("login" in s and s["login"] == ("bot@example.com", "hunter2") for s in sent)
 96      assert any("send" in s and s["send"] == "admin@example.com" for s in sent)
 97  
 98  
 99  def test_send_email_implicit_tls_path():
100      """Port 465 should use SMTP_SSL, not STARTTLS."""
101      from restai.llms.tools.send_email import send_email
102      opts = {
103          "smtp_host": "smtp.example.com",
104          "smtp_port": 465,
105          "smtp_from": "bot@example.com",
106          "email_default_to": "admin@example.com",
107      }
108      db = _fake_db(_fake_project(opts))
109  
110      used_ssl = {"flag": False}
111      class _FakeSSL:
112          def __init__(self, host, port, timeout=None):
113              used_ssl["flag"] = True
114          def __enter__(self): return self
115          def __exit__(self, *a): return False
116          def login(self, u, p): pass
117          def send_message(self, msg): pass
118  
119      class _FakePlainSMTP:
120          def __init__(self, *a, **kw):
121              raise AssertionError("plain SMTP must not be used on port 465")
122  
123      with patch("restai.database.get_db_wrapper", return_value=db), \
124           patch("smtplib.SMTP_SSL", _FakeSSL), \
125           patch("smtplib.SMTP", _FakePlainSMTP):
126          out = send_email("subj", "body", _brain=object(), _project_id=1)
127  
128      assert out.startswith("OK:"), out
129      assert used_ssl["flag"] is True
130  
131  
def test_send_email_smtp_failure_returns_error_string():
    """An SMTPException from the transport comes back as an ERROR string, not raised."""
    from restai.llms.tools.send_email import send_email

    settings = {
        "smtp_host": "smtp.example.com",
        "smtp_port": 587,
        "smtp_from": "bot@example.com",
        "email_default_to": "admin@example.com",
    }
    fake_db = _fake_db(_fake_project(settings))

    def _refuse_relay(*args, **kwargs):
        # Stands in for smtplib.SMTP: fails as soon as a connection is attempted.
        raise smtplib.SMTPException("relay refused")

    with patch("restai.database.get_db_wrapper", return_value=fake_db), \
         patch("smtplib.SMTP", _refuse_relay):
        result = send_email("subj", "body", _brain=object(), _project_id=1)

    assert result.startswith("ERROR:"), result
    assert "relay refused" in result
150  
151  
152  # ─── send_sms ───────────────────────────────────────────────────────────
153  
def test_send_sms_requires_project_context():
    """Calling send_sms without _brain/_project_id must yield an ERROR string."""
    from restai.llms.tools.send_sms import send_sms

    result = send_sms("hi")

    assert result.startswith("ERROR:")
    assert "project context" in result
159  
160  
def test_send_sms_missing_config():
    """A project with no Twilio options is reported as not configured."""
    from restai.llms.tools.send_sms import send_sms

    fake_db = _fake_db(_fake_project({}))
    with patch("restai.database.get_db_wrapper", return_value=fake_db):
        result = send_sms("hi", _brain=object(), _project_id=1)

    assert result.startswith("ERROR:")
    assert "not configured" in result
168  
169  
def test_send_sms_missing_recipient():
    """Twilio credentials present but no sms_default_to -> ERROR naming the recipient."""
    from restai.llms.tools.send_sms import send_sms

    twilio_settings = {
        "twilio_account_sid": "AC123",
        "twilio_auth_token": encrypt_field("tok"),
        "twilio_from_number": "+15551234567",
    }
    fake_db = _fake_db(_fake_project(twilio_settings))
    with patch("restai.database.get_db_wrapper", return_value=fake_db):
        result = send_sms("hi", _brain=object(), _project_id=1)

    assert result.startswith("ERROR:")
    assert "recipient" in result
181  
182  
def test_send_sms_happy_path():
    """A short SMS posts to Twilio's Messages endpoint and reports the SID.

    The POST must carry the account SID plus the *decrypted* auth token,
    the configured From/To numbers, the body text, and a timeout.
    """
    from restai.llms.tools.send_sms import send_sms

    project = _fake_project({
        "twilio_account_sid": "AC123",
        "twilio_auth_token": encrypt_field("tok_secret"),
        "twilio_from_number": "+15551234567",
        "sms_default_to": "+351912345678",
    })
    fake_db = _fake_db(project)

    # Arguments of the (last) requests.post call made by the tool.
    seen = {}

    class _CreatedResp:
        status_code = 201

        def json(self):
            return {"sid": "SMxxxx"}

    def fake_post(url, auth=None, data=None, timeout=None):
        seen["url"] = url
        seen["auth"] = auth
        seen["data"] = data
        seen["timeout"] = timeout
        return _CreatedResp()

    with patch("restai.database.get_db_wrapper", return_value=fake_db), \
         patch("requests.post", fake_post):
        result = send_sms("hi from agent", _brain=object(), _project_id=1)

    expected_payload = {
        "From": "+15551234567",
        "To": "+351912345678",
        "Body": "hi from agent",
    }
    assert result.startswith("OK:"), result
    assert "SMxxxx" in result
    assert "AC123/Messages.json" in seen["url"]
    assert seen["auth"] == ("AC123", "tok_secret")
    assert seen["data"] == expected_payload
    assert seen["timeout"] is not None
210  
211  
def test_send_sms_chunks_long_messages():
    """A long body is split into 1600-character parts sent in order.

    The payload is non-uniform, so the test also proves that the parts
    reassemble to the original text (no reordering, duplication or loss).
    A run of identical characters can only show the part lengths.
    """
    from restai.llms.tools.send_sms import send_sms
    db = _fake_db(_fake_project({
        "twilio_account_sid": "AC123",
        "twilio_auth_token": encrypt_field("tok"),
        "twilio_from_number": "+15551234567",
        "sms_default_to": "+351912345678",
    }))

    # Bodies of every requests.post call, in call order.
    bodies = []

    class _FakeResp:
        status_code = 201

        def json(self):
            return {"sid": "SMxxxx"}

    def fake_post(url, auth=None, data=None, timeout=None):
        bodies.append(data["Body"])
        return _FakeResp()

    # 3500 chars and no whitespace, so a word-aware splitter has nowhere to
    # break early and the expected cut points are unambiguous.
    message = "0123456789" * 350

    with patch("restai.database.get_db_wrapper", return_value=db), \
         patch("requests.post", fake_post):
        out = send_sms(message, _brain=object(), _project_id=1)

    assert out.startswith("OK:"), out
    # 1600-char chunks → 3 parts (1600 + 1600 + 300).
    assert [len(b) for b in bodies] == [1600, 1600, 300]
    assert "".join(bodies) == message
235  
236  
def test_send_sms_twilio_error_surfaces():
    """A 400 reply from Twilio is returned as an ERROR carrying Twilio's message."""
    from restai.llms.tools.send_sms import send_sms

    fake_db = _fake_db(_fake_project({
        "twilio_account_sid": "AC123",
        "twilio_auth_token": encrypt_field("tok"),
        "twilio_from_number": "+15551234567",
        "sms_default_to": "+351912345678",
    }))

    class _RejectedResp:
        status_code = 400
        text = '{"message": "from number not owned"}'

        def json(self):
            return {"message": "from number not owned", "code": 21603}

    def _reject(*args, **kwargs):
        # Every POST gets the same rejection, whatever the tool sends.
        return _RejectedResp()

    with patch("restai.database.get_db_wrapper", return_value=fake_db), \
         patch("requests.post", _reject):
        result = send_sms("hi", _brain=object(), _project_id=1)

    assert result.startswith("ERROR:")
    assert "from number not owned" in result