/ tests / gateway / test_usage_command.py
test_usage_command.py
  1  """Tests for gateway /usage command — agent cache lookup and output fields."""
  2  
  3  import asyncio
  4  import threading
  5  from unittest.mock import MagicMock, patch
  6  
  7  import pytest
  8  
  9  
 10  def _make_mock_agent(**overrides):
 11      """Create a mock AIAgent with realistic session counters."""
 12      agent = MagicMock()
 13      defaults = {
 14          "model": "anthropic/claude-sonnet-4.6",
 15          "provider": "openrouter",
 16          "base_url": None,
 17          "session_total_tokens": 50_000,
 18          "session_api_calls": 5,
 19          "session_prompt_tokens": 40_000,
 20          "session_completion_tokens": 10_000,
 21          "session_input_tokens": 35_000,
 22          "session_output_tokens": 10_000,
 23          "session_cache_read_tokens": 5_000,
 24          "session_cache_write_tokens": 2_000,
 25      }
 26      defaults.update(overrides)
 27      for k, v in defaults.items():
 28          setattr(agent, k, v)
 29  
 30      # Rate limit state
 31      rl = MagicMock()
 32      rl.has_data = True
 33      agent.get_rate_limit_state.return_value = rl
 34  
 35      # Context compressor
 36      ctx = MagicMock()
 37      ctx.last_prompt_tokens = 30_000
 38      ctx.context_length = 200_000
 39      ctx.compression_count = 1
 40      agent.context_compressor = ctx
 41  
 42      return agent
 43  
 44  
 45  def _make_runner(session_key, agent=None, cached_agent=None):
 46      """Build a bare GatewayRunner with just the fields _handle_usage_command needs."""
 47      from gateway.run import GatewayRunner, _AGENT_PENDING_SENTINEL
 48  
 49      runner = object.__new__(GatewayRunner)
 50      runner._running_agents = {}
 51      runner._running_agents_ts = {}
 52      runner._agent_cache = {}
 53      runner._agent_cache_lock = threading.Lock()
 54      runner.session_store = MagicMock()
 55  
 56      if agent is not None:
 57          runner._running_agents[session_key] = agent
 58  
 59      if cached_agent is not None:
 60          runner._agent_cache[session_key] = (cached_agent, "sig")
 61  
 62      # Wire helper
 63      runner._session_key_for_source = MagicMock(return_value=session_key)
 64  
 65      return runner
 66  
 67  
# Session key used by the tests in this module; _make_runner stubs
# _session_key_for_source to return whatever key it is given.
SK = "agent:main:telegram:private:12345"
 69  
 70  
 71  class TestUsageCachedAgent:
 72      """The main fix: /usage should find agents in _agent_cache between turns."""
 73  
 74      @pytest.mark.asyncio
 75      async def test_cached_agent_shows_detailed_usage(self):
 76          agent = _make_mock_agent()
 77          runner = _make_runner(SK, cached_agent=agent)
 78          event = MagicMock()
 79  
 80          with patch("agent.rate_limit_tracker.format_rate_limit_compact", return_value="RPM: 50/60"), \
 81               patch("agent.usage_pricing.estimate_usage_cost") as mock_cost:
 82              mock_cost.return_value = MagicMock(amount_usd=0.1234, status="estimated")
 83              result = await runner._handle_usage_command(event)
 84  
 85          assert "claude-sonnet-4.6" in result
 86          assert "35,000" in result  # input tokens
 87          assert "10,000" in result  # output tokens
 88          assert "5,000" in result   # cache read
 89          assert "2,000" in result   # cache write
 90          assert "50,000" in result  # total
 91          assert "$0.1234" in result
 92          assert "30,000" in result  # context
 93          assert "Compressions: 1" in result
 94  
 95      @pytest.mark.asyncio
 96      async def test_running_agent_preferred_over_cache(self):
 97          """When agent is in both dicts, the running one wins."""
 98          running = _make_mock_agent(session_api_calls=10, session_total_tokens=80_000)
 99          cached = _make_mock_agent(session_api_calls=5, session_total_tokens=50_000)
100          runner = _make_runner(SK, agent=running, cached_agent=cached)
101          event = MagicMock()
102  
103          with patch("agent.rate_limit_tracker.format_rate_limit_compact", return_value="RPM: 50/60"), \
104               patch("agent.usage_pricing.estimate_usage_cost") as mock_cost:
105              mock_cost.return_value = MagicMock(amount_usd=None, status="unknown")
106              result = await runner._handle_usage_command(event)
107  
108          assert "80,000" in result   # running agent's total
109          assert "API calls: 10" in result
110  
111      @pytest.mark.asyncio
112      async def test_sentinel_skipped_uses_cache(self):
113          """PENDING sentinel in _running_agents should fall through to cache."""
114          from gateway.run import _AGENT_PENDING_SENTINEL
115  
116          cached = _make_mock_agent()
117          runner = _make_runner(SK, cached_agent=cached)
118          runner._running_agents[SK] = _AGENT_PENDING_SENTINEL
119          event = MagicMock()
120  
121          with patch("agent.rate_limit_tracker.format_rate_limit_compact", return_value="RPM: 50/60"), \
122               patch("agent.usage_pricing.estimate_usage_cost") as mock_cost:
123              mock_cost.return_value = MagicMock(amount_usd=None, status="unknown")
124              result = await runner._handle_usage_command(event)
125  
126          assert "claude-sonnet-4.6" in result
127          assert "Session Token Usage" in result
128  
129      @pytest.mark.asyncio
130      async def test_no_agent_anywhere_falls_to_history(self):
131          """No running or cached agent → rough estimate from transcript."""
132          runner = _make_runner(SK)
133          event = MagicMock()
134  
135          session_entry = MagicMock()
136          session_entry.session_id = "sess123"
137          runner.session_store.get_or_create_session.return_value = session_entry
138          runner.session_store.load_transcript.return_value = [
139              {"role": "user", "content": "hello"},
140              {"role": "assistant", "content": "hi there"},
141          ]
142  
143          with patch("agent.model_metadata.estimate_messages_tokens_rough", return_value=500):
144              result = await runner._handle_usage_command(event)
145  
146          assert "Session Info" in result
147          assert "Messages: 2" in result
148          assert "~500" in result
149  
150      @pytest.mark.asyncio
151      async def test_cache_read_write_hidden_when_zero(self):
152          """Cache token lines should be omitted when zero."""
153          agent = _make_mock_agent(session_cache_read_tokens=0, session_cache_write_tokens=0)
154          runner = _make_runner(SK, cached_agent=agent)
155          event = MagicMock()
156  
157          with patch("agent.rate_limit_tracker.format_rate_limit_compact", return_value="RPM: 50/60"), \
158               patch("agent.usage_pricing.estimate_usage_cost") as mock_cost:
159              mock_cost.return_value = MagicMock(amount_usd=None, status="unknown")
160              result = await runner._handle_usage_command(event)
161  
162          assert "Cache read" not in result
163          assert "Cache write" not in result
164  
165      @pytest.mark.asyncio
166      async def test_cost_included_status(self):
167          """Subscription-included providers show 'included' instead of dollar amount."""
168          agent = _make_mock_agent(provider="openai-codex")
169          runner = _make_runner(SK, cached_agent=agent)
170          event = MagicMock()
171  
172          with patch("agent.rate_limit_tracker.format_rate_limit_compact", return_value="RPM: 50/60"), \
173               patch("agent.usage_pricing.estimate_usage_cost") as mock_cost:
174              mock_cost.return_value = MagicMock(amount_usd=None, status="included")
175              result = await runner._handle_usage_command(event)
176  
177          assert "Cost: included" in result
178  
179  
class TestUsageAccountSection:
    """Account-limits section appended to /usage output (PR #2486)."""

    @pytest.mark.asyncio
    async def test_usage_command_includes_account_section(self, monkeypatch):
        """With a cached agent, the reply has both session usage and account limits."""

        def _fetch_stub(provider, base_url=None, api_key=None):
            # Opaque snapshot; it is only handed on to the render stub.
            return object()

        def _render_stub(snapshot, markdown=False):
            return [
                "📈 **Account limits**",
                "Provider: openai-codex (Pro)",
                "Session: 85% remaining (15% used)",
            ]

        monkeypatch.setattr("gateway.run.fetch_account_usage", _fetch_stub)
        monkeypatch.setattr("gateway.run.render_account_usage_lines", _render_stub)

        codex_agent = _make_mock_agent(
            provider="openai-codex",
            base_url="https://chatgpt.com/backend-api/codex",
            api_key="unused",
        )
        runner = _make_runner(SK, cached_agent=codex_agent)

        cost_stub = MagicMock(amount_usd=None, status="included")
        with patch("agent.rate_limit_tracker.format_rate_limit_compact", return_value="RPM: 50/60"):
            with patch("agent.usage_pricing.estimate_usage_cost", return_value=cost_stub):
                reply = await runner._handle_usage_command(MagicMock())

        for expected in (
            "📊 **Session Token Usage**",
            "📈 **Account limits**",
            "Provider: openai-codex (Pro)",
        ):
            assert expected in reply

    @pytest.mark.asyncio
    async def test_usage_command_uses_persisted_provider_when_agent_not_running(self, monkeypatch):
        """With no agent, the account fetch gets the provider/base URL from the session DB row."""
        # Session DB row carrying the billing provider details.
        session_db = MagicMock()
        session_db.get_session.return_value = {
            "billing_provider": "openai-codex",
            "billing_base_url": "https://chatgpt.com/backend-api/codex",
        }

        runner = _make_runner(SK)
        runner._session_db = session_db
        runner.session_store.get_or_create_session.return_value = MagicMock(session_id="sess-1")
        runner.session_store.load_transcript.return_value = [
            {"role": "user", "content": "earlier"},
        ]

        recorded = {}

        async def _inline_to_thread(fn, *args, **kwargs):
            # Run the callable inline and remember how it was invoked.
            recorded.update(args=args, kwargs=kwargs)
            return fn(*args, **kwargs)

        def _fetch_stub(provider, base_url=None, api_key=None):
            return object()

        def _render_stub(snapshot, markdown=False):
            return [
                "📈 **Account limits**",
                "Provider: openai-codex (Pro)",
            ]

        monkeypatch.setattr("gateway.run.asyncio.to_thread", _inline_to_thread)
        monkeypatch.setattr("gateway.run.fetch_account_usage", _fetch_stub)
        monkeypatch.setattr("gateway.run.render_account_usage_lines", _render_stub)

        reply = await runner._handle_usage_command(MagicMock())

        assert recorded["args"] == ("openai-codex",)
        assert recorded["kwargs"]["base_url"] == "https://chatgpt.com/backend-api/codex"
        assert "📊 **Session Info**" in reply
        assert "📈 **Account limits**" in reply