"""
test_yuanbao_pipeline.py - Unit tests for the inbound middleware pipeline.

Tests cover:
    1. InboundPipeline engine (use, use_before, use_after, remove, execute)
    2. InboundContext dataclass
    3. Individual middlewares (DecodeMiddleware, DedupMiddleware, SkipSelfMiddleware, etc.)
    4. InboundPipelineBuilder
    5. End-to-end pipeline integration
    6. OOP middleware ABC and class tests
"""

import sys
import os
import json
import asyncio

# Ensure project root is on the path
_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

import pytest
from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock

from gateway.platforms.yuanbao import (
    InboundContext,
    InboundMiddleware,
    InboundPipeline,
    DecodeMiddleware,
    ExtractFieldsMiddleware,
    DedupMiddleware,
    SkipSelfMiddleware,
    ChatRoutingMiddleware,
    AccessPolicy,
    AccessGuardMiddleware,
    ExtractContentMiddleware,
    PlaceholderFilterMiddleware,
    OwnerCommandMiddleware,
    BuildSourceMiddleware,
    GroupAtGuardMiddleware,
    DispatchMiddleware,
    InboundPipelineBuilder,
    YuanbaoAdapter,
)
from gateway.config import Platform, PlatformConfig


# ============================================================
# Helpers
# ============================================================

def make_config(**kwargs):
    """Build a PlatformConfig whose ``extra`` carries Yuanbao test credentials.

    ``app_id``, ``app_secret``, ``ws_url`` and ``api_domain`` are filled in
    only when the caller did not supply them. A caller-provided ``extra`` dict
    is copied first, so the caller's object is never mutated; ``extra=None``
    is treated like an empty dict. All other keyword arguments are forwarded
    to ``PlatformConfig`` unchanged.
    """
    # Copy so setdefault() below cannot leak test defaults into the caller's dict.
    extra = dict(kwargs.pop("extra", None) or {})
    extra.setdefault("app_id", "test_key")
    extra.setdefault("app_secret", "test_secret")
    extra.setdefault("ws_url", "wss://test.example.com/ws")
    extra.setdefault("api_domain", "https://test.example.com")
    return PlatformConfig(
        extra=extra,
        **kwargs,
    )


def make_adapter(**kwargs) -> YuanbaoAdapter:
    """Create a YuanbaoAdapter with test config.

    ``_bot_id`` is forced to ``"bot_123"``, which matches the default
    ``to_account`` used by ``make_json_push``.
    """
    config = make_config(**kwargs)
    adapter = YuanbaoAdapter(config)
    adapter._bot_id = "bot_123"
    return adapter


def make_ctx(adapter=None, conn_data=b"", **overrides) -> InboundContext:
    """Create an InboundContext with sensible defaults for testing.

    Args:
        adapter: Adapter to attach; a fresh ``make_adapter()`` when omitted.
        conn_data: Raw frame bytes; when non-empty it becomes the single
            entry of ``raw_frames``, otherwise ``raw_frames`` is empty.
        **overrides: Attributes assigned onto the context via ``setattr``
            after construction.
    """
    if adapter is None:
        adapter = make_adapter()
    raw_frames = [conn_data] if conn_data else []
    ctx = InboundContext(adapter=adapter, raw_frames=raw_frames)
    for k, v in overrides.items():
        setattr(ctx, k, v)
    return ctx


def make_json_push(
    from_account="alice",
    to_account="bot_123",
    group_code="",
    text="Hello!",
    msg_id="msg-001",
) -> bytes:
    """Build a JSON callback_command push payload.

    A non-empty ``group_code`` switches the command to
    ``Group.CallbackAfterSendMsg`` and adds ``GroupId``; otherwise the payload
    is a C2C message. Returns the UTF-8 encoded JSON bytes.

    Note: MsgContent inner fields use lowercase ("text" not "Text")
    because _extract_text() looks for lowercase keys.
    """
    msg_body = [{"MsgType": "TIMTextElem", "MsgContent": {"text": text}}]
    push = {
        "CallbackCommand": "C2C.CallbackAfterSendMsg",
        "From_Account": from_account,
        "To_Account": to_account,
        "MsgBody": msg_body,
        "MsgKey": msg_id,
    }
    if group_code:
        push["CallbackCommand"] = "Group.CallbackAfterSendMsg"
        push["GroupId"] = group_code
    return json.dumps(push).encode("utf-8")
# ============================================================
# 1. InboundPipeline Engine Tests
# ============================================================

class TestInboundPipeline:
    """Test the pipeline engine itself."""

    @staticmethod
    def _recorder(label, log):
        """Return a middleware that appends ``label`` to ``log`` then continues."""
        async def _record(ctx, next_fn):
            log.append(label)
            await next_fn()
        return _record

    @staticmethod
    async def _passthrough(ctx, next_fn):
        """Middleware that only hands control to the next middleware."""
        await next_fn()

    @pytest.mark.asyncio
    async def test_empty_pipeline(self):
        """Empty pipeline executes without error."""
        # Must simply return; any exception fails the test.
        await InboundPipeline().execute(make_ctx())

    @pytest.mark.asyncio
    async def test_single_middleware(self):
        """Single middleware is called with ctx and next_fn."""
        seen = []
        pipeline = InboundPipeline().use("test", self._recorder("mw", seen))
        await pipeline.execute(make_ctx())
        assert seen == ["mw"]

    @pytest.mark.asyncio
    async def test_middleware_order(self):
        """Middlewares execute in registration order."""
        trace = []
        pipeline = InboundPipeline()
        for label in ("a", "b", "c"):
            pipeline = pipeline.use(label, self._recorder(label, trace))
        await pipeline.execute(make_ctx())
        assert trace == ["a", "b", "c"]

    @pytest.mark.asyncio
    async def test_middleware_can_stop_pipeline(self):
        """A middleware that doesn't call next_fn stops the pipeline."""
        trace = []

        async def halt(ctx, next_fn):
            trace.append("stop")
            # Deliberately never awaits next_fn — downstream must not run.

        pipeline = (
            InboundPipeline()
            .use("stop", halt)
            .use("after", self._recorder("after", trace))
        )
        await pipeline.execute(make_ctx())
        assert trace == ["stop"]  # "after" should NOT be called

    @pytest.mark.asyncio
    async def test_conditional_guard_skip(self):
        """Middleware with when=False is skipped."""
        trace = []
        pipeline = (
            InboundPipeline()
            .use("a", self._recorder("a", trace))
            .use("skipped", self._recorder("skipped", trace), when=lambda ctx: False)
            .use("c", self._recorder("c", trace))
        )
        await pipeline.execute(make_ctx())
        assert trace == ["a", "c"]

    @pytest.mark.asyncio
    async def test_conditional_guard_pass(self):
        """Middleware with when=True is executed."""
        trace = []
        pipeline = InboundPipeline().use(
            "mw", self._recorder("mw", trace), when=lambda ctx: True
        )
        await pipeline.execute(make_ctx())
        assert trace == ["mw"]

    def test_use_before(self):
        """use_before inserts middleware before the target."""
        noop = self._passthrough
        pipeline = InboundPipeline().use("a", noop).use("c", noop)
        pipeline.use_before("c", "b", noop)
        assert pipeline.middleware_names == ["a", "b", "c"]

    def test_use_before_nonexistent_appends(self):
        """use_before with nonexistent target appends to end."""
        noop = self._passthrough
        pipeline = InboundPipeline().use("a", noop)
        pipeline.use_before("nonexistent", "b", noop)
        assert pipeline.middleware_names == ["a", "b"]

    def test_use_after(self):
        """use_after inserts middleware after the target."""
        noop = self._passthrough
        pipeline = InboundPipeline().use("a", noop).use("c", noop)
        pipeline.use_after("a", "b", noop)
        assert pipeline.middleware_names == ["a", "b", "c"]

    def test_use_after_nonexistent_appends(self):
        """use_after with nonexistent target appends to end."""
        noop = self._passthrough
        pipeline = InboundPipeline().use("a", noop)
        pipeline.use_after("nonexistent", "b", noop)
        assert pipeline.middleware_names == ["a", "b"]

    def test_remove(self):
        """remove deletes middleware by name."""
        noop = self._passthrough
        pipeline = InboundPipeline()
        for label in ("a", "b", "c"):
            pipeline = pipeline.use(label, noop)
        pipeline.remove("b")
        assert pipeline.middleware_names == ["a", "c"]

    def test_remove_nonexistent_is_noop(self):
        """remove with nonexistent name is a no-op."""
        pipeline = InboundPipeline().use("a", self._passthrough)
        pipeline.remove("nonexistent")
        assert pipeline.middleware_names == ["a"]

    @pytest.mark.asyncio
    async def test_error_propagation(self):
        """Errors in middlewares propagate to the caller."""
        async def explode(ctx, next_fn):
            raise ValueError("test error")

        pipeline = InboundPipeline().use("error", explode)
        with pytest.raises(ValueError, match="test error"):
            await pipeline.execute(make_ctx())

    def test_middleware_names_property(self):
        """middleware_names returns ordered list of names."""
        stages = ["decode", "dedup", "dispatch"]
        pipeline = InboundPipeline()
        for stage in stages:
            pipeline = pipeline.use(stage, self._passthrough)
        assert pipeline.middleware_names == ["decode", "dedup", "dispatch"]

    @pytest.mark.asyncio
    async def test_onion_model(self):
        """Middlewares support before/after processing (onion model)."""
        trace = []

        async def wrapper(ctx, next_fn):
            trace.append("outer-before")
            await next_fn()
            # Runs only after every inner middleware has returned.
            trace.append("outer-after")

        pipeline = (
            InboundPipeline()
            .use("outer", wrapper)
            .use("inner", self._recorder("inner", trace))
        )
        await pipeline.execute(make_ctx())
        assert trace == ["outer-before", "inner", "outer-after"]
# ============================================================
# 2. InboundContext Tests
# ============================================================

class TestInboundContext:
    def test_default_values(self):
        """InboundContext has sensible defaults."""
        ctx = InboundContext(adapter=make_adapter())

        # Fields compared by equality (empty strings / empty lists).
        equal_defaults = {
            "raw_frames": [],
            "decoded_via": "",
            "from_account": "",
            "group_code": "",
            "msg_body": [],
            "msg_id": "",
            "chat_id": "",
            "chat_type": "",
            "raw_text": "",
            "media_refs": [],
        }
        for field_name, expected in equal_defaults.items():
            assert getattr(ctx, field_name) == expected, field_name

        # Fields that must start out as the None singleton.
        for field_name in ("push", "owner_command", "source", "msg_type"):
            assert getattr(ctx, field_name) is None, field_name

    def test_mutable_fields(self):
        """InboundContext fields are mutable."""
        ctx = make_ctx()
        updates = {"from_account": "alice", "chat_type": "dm"}
        for field_name, value in updates.items():
            setattr(ctx, field_name, value)
        assert ctx.from_account == "alice"
        assert ctx.chat_type == "dm"
# ============================================================
# 3. Individual Middleware Tests
# ============================================================

async def _invoke(middleware_cls, ctx):
    """Instantiate ``middleware_cls``, run it once on ``ctx``, return the next_fn mock.

    Callers inspect the returned AsyncMock to tell whether the middleware
    passed control downstream (``assert_awaited_once``) or stopped the
    pipeline (``assert_not_awaited``).
    """
    downstream = AsyncMock()
    await middleware_cls()(ctx, downstream)
    return downstream


def _access_policy(dm="open", group="open", dm_allow=None, group_allow=None):
    """Build an AccessPolicy for tests; allowlists default to empty lists."""
    return AccessPolicy(
        dm_policy=dm,
        dm_allow_from=list(dm_allow or []),
        group_policy=group,
        group_allow_from=list(group_allow or []),
    )


def _adapter_with_policy(**policy_kwargs):
    """Return a test adapter whose ``_access_policy`` comes from ``_access_policy``."""
    adapter = make_adapter()
    adapter._access_policy = _access_policy(**policy_kwargs)
    return adapter


class TestDecodeMiddleware:
    @pytest.mark.asyncio
    async def test_json_decode(self):
        """DecodeMiddleware parses JSON push correctly."""
        frame = make_json_push(from_account="alice", text="hi")
        ctx = make_ctx(conn_data=frame)

        downstream = await _invoke(DecodeMiddleware, ctx)

        assert ctx.push is not None
        assert ctx.decoded_via == "json"
        assert ctx.push.get("from_account") == "alice"
        downstream.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_empty_data_stops_pipeline(self):
        """DecodeMiddleware stops pipeline on empty conn_data."""
        ctx = make_ctx(conn_data=b"")

        downstream = await _invoke(DecodeMiddleware, ctx)

        assert ctx.push is None
        downstream.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_invalid_data_may_produce_garbage(self):
        """DecodeMiddleware: binary data may be parsed by protobuf as garbage fields.

        This is expected behavior — the protobuf parser is lenient and may
        produce "seemingly valid" fields from arbitrary bytes. The downstream
        middlewares (dedup, skip-self, etc.) will filter out such garbage.
        """
        ctx = make_ctx(conn_data=b"\x00\x01\x02\x03")
        # Protobuf parser may or may not produce a result — either is acceptable.
        # The only invariant checked here: the call completes without raising.
        await _invoke(DecodeMiddleware, ctx)


class TestExtractFieldsMiddleware:
    @pytest.mark.asyncio
    async def test_extracts_fields(self):
        """ExtractFieldsMiddleware populates ctx from push dict."""
        payload = {
            "from_account": "alice",
            "group_code": "grp-1",
            "group_name": "Test Group",
            "sender_nickname": "Alice",
            "msg_body": [{"msg_type": "TIMTextElem", "msg_content": {"text": "hi"}}],
            "msg_id": "msg-001",
            "cloud_custom_data": '{"key": "val"}',
        }
        ctx = make_ctx(push=payload)

        downstream = await _invoke(ExtractFieldsMiddleware, ctx)

        assert ctx.from_account == "alice"
        assert ctx.group_code == "grp-1"
        assert ctx.group_name == "Test Group"
        assert ctx.sender_nickname == "Alice"
        assert len(ctx.msg_body) == 1
        assert ctx.msg_id == "msg-001"
        assert ctx.cloud_custom_data == '{"key": "val"}'
        downstream.assert_awaited_once()


class TestDedupMiddleware:
    @pytest.mark.asyncio
    async def test_new_message_passes(self):
        """DedupMiddleware passes new messages through."""
        ctx = make_ctx(adapter=make_adapter(), msg_id="unique-msg-001")
        downstream = await _invoke(DedupMiddleware, ctx)
        downstream.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_duplicate_stops_pipeline(self):
        """DedupMiddleware stops pipeline for duplicate messages."""
        adapter = make_adapter()
        # The first is_duplicate() call records the id as already seen.
        adapter._dedup.is_duplicate("dup-msg-001")

        ctx = make_ctx(adapter=adapter, msg_id="dup-msg-001")
        downstream = await _invoke(DedupMiddleware, ctx)
        downstream.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_empty_msg_id_passes(self):
        """DedupMiddleware passes messages with empty msg_id."""
        downstream = await _invoke(DedupMiddleware, make_ctx(msg_id=""))
        downstream.assert_awaited_once()


class TestSkipSelfMiddleware:
    @staticmethod
    def _bot_adapter():
        """Adapter whose bot id is explicitly ``bot_123``."""
        adapter = make_adapter()
        adapter._bot_id = "bot_123"
        return adapter

    @pytest.mark.asyncio
    async def test_self_message_stops(self):
        """SkipSelfMiddleware stops pipeline for bot's own messages."""
        ctx = make_ctx(adapter=self._bot_adapter(), from_account="bot_123")
        downstream = await _invoke(SkipSelfMiddleware, ctx)
        downstream.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_other_message_passes(self):
        """SkipSelfMiddleware passes messages from other users."""
        ctx = make_ctx(adapter=self._bot_adapter(), from_account="alice")
        downstream = await _invoke(SkipSelfMiddleware, ctx)
        downstream.assert_awaited_once()


class TestChatRoutingMiddleware:
    @pytest.mark.asyncio
    async def test_group_routing(self):
        """ChatRoutingMiddleware sets group chat fields."""
        ctx = make_ctx(group_code="grp-1", group_name="Test Group")

        downstream = await _invoke(ChatRoutingMiddleware, ctx)

        assert ctx.chat_id == "group:grp-1"
        assert ctx.chat_type == "group"
        assert ctx.chat_name == "Test Group"
        downstream.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_dm_routing(self):
        """ChatRoutingMiddleware sets DM chat fields."""
        ctx = make_ctx(from_account="alice", sender_nickname="Alice")

        downstream = await _invoke(ChatRoutingMiddleware, ctx)

        assert ctx.chat_id == "direct:alice"
        assert ctx.chat_type == "dm"
        assert ctx.chat_name == "Alice"
        downstream.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_dm_routing_no_nickname(self):
        """ChatRoutingMiddleware falls back to from_account when no nickname."""
        ctx = make_ctx(from_account="alice", sender_nickname="")
        await _invoke(ChatRoutingMiddleware, ctx)
        assert ctx.chat_name == "alice"


class TestAccessGuardMiddleware:
    @pytest.mark.asyncio
    async def test_open_policy_passes(self):
        """AccessGuardMiddleware passes with open policy."""
        adapter = _adapter_with_policy(dm="open", group="open")
        ctx = make_ctx(adapter=adapter, chat_type="dm", from_account="alice")
        downstream = await _invoke(AccessGuardMiddleware, ctx)
        downstream.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_disabled_dm_stops(self):
        """AccessGuardMiddleware stops DM when dm_policy=disabled."""
        adapter = _adapter_with_policy(dm="disabled", group="open")
        ctx = make_ctx(adapter=adapter, chat_type="dm", from_account="alice")
        downstream = await _invoke(AccessGuardMiddleware, ctx)
        downstream.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_allowlist_dm_allowed(self):
        """AccessGuardMiddleware passes DM when sender is in allowlist."""
        adapter = _adapter_with_policy(dm="allowlist", dm_allow=["alice"], group="open")
        ctx = make_ctx(adapter=adapter, chat_type="dm", from_account="alice")
        downstream = await _invoke(AccessGuardMiddleware, ctx)
        downstream.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_allowlist_dm_blocked(self):
        """AccessGuardMiddleware blocks DM when sender is not in allowlist."""
        adapter = _adapter_with_policy(dm="allowlist", dm_allow=["bob"], group="open")
        ctx = make_ctx(adapter=adapter, chat_type="dm", from_account="alice")
        downstream = await _invoke(AccessGuardMiddleware, ctx)
        downstream.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_disabled_group_stops(self):
        """AccessGuardMiddleware stops group when group_policy=disabled."""
        adapter = _adapter_with_policy(dm="open", group="disabled")
        ctx = make_ctx(adapter=adapter, chat_type="group", group_code="grp-1")
        downstream = await _invoke(AccessGuardMiddleware, ctx)
        downstream.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_allowlist_group_allowed(self):
        """AccessGuardMiddleware passes group when group_code is in allowlist."""
        adapter = _adapter_with_policy(dm="open", group="allowlist", group_allow=["grp-1"])
        ctx = make_ctx(adapter=adapter, chat_type="group", group_code="grp-1")
        downstream = await _invoke(AccessGuardMiddleware, ctx)
        downstream.assert_awaited_once()


class TestExtractContentMiddleware:
    @pytest.mark.asyncio
    async def test_extracts_text_and_media(self):
        """ExtractContentMiddleware extracts text and media refs."""
        text_elem = {"msg_type": "TIMTextElem", "msg_content": {"text": "Hello!"}}
        image_elem = {
            "msg_type": "TIMImageElem",
            "msg_content": {
                "image_info_array": [{"url": "https://img.example.com/1.jpg"}]
            },
        }
        ctx = make_ctx(adapter=make_adapter(), msg_body=[text_elem, image_elem])

        downstream = await _invoke(ExtractContentMiddleware, ctx)

        assert "Hello!" in ctx.raw_text
        assert len(ctx.media_refs) == 1
        assert ctx.media_refs[0]["kind"] == "image"
        downstream.assert_awaited_once()


class TestPlaceholderFilterMiddleware:
    @pytest.mark.asyncio
    async def test_placeholder_stops(self):
        """PlaceholderFilterMiddleware stops on pure placeholder."""
        ctx = make_ctx(raw_text="[image]", media_refs=[])
        downstream = await _invoke(PlaceholderFilterMiddleware, ctx)
        downstream.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_placeholder_with_media_passes(self):
        """PlaceholderFilterMiddleware passes placeholder when media exists."""
        attachment = {"kind": "image", "url": "https://img.example.com/1.jpg"}
        ctx = make_ctx(raw_text="[image]", media_refs=[attachment])
        downstream = await _invoke(PlaceholderFilterMiddleware, ctx)
        downstream.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_normal_text_passes(self):
        """PlaceholderFilterMiddleware passes normal text."""
        downstream = await _invoke(
            PlaceholderFilterMiddleware, make_ctx(raw_text="Hello world!")
        )
        downstream.assert_awaited_once()


class TestGroupAtGuardMiddleware:
    @staticmethod
    def _bot_adapter():
        """Adapter whose bot id is explicitly ``bot_123``."""
        adapter = make_adapter()
        adapter._bot_id = "bot_123"
        return adapter

    @pytest.mark.asyncio
    async def test_dm_passes(self):
        """GroupAtGuardMiddleware passes DM messages."""
        ctx = make_ctx(adapter=make_adapter(), chat_type="dm")
        downstream = await _invoke(GroupAtGuardMiddleware, ctx)
        downstream.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_group_with_at_bot_passes(self):
        """GroupAtGuardMiddleware passes group messages that @bot."""
        at_payload = json.dumps({"elem_type": 1002, "text": "@Bot", "user_id": "bot_123"})
        mention = {"msg_type": "TIMCustomElem", "msg_content": {"data": at_payload}}
        ctx = make_ctx(
            adapter=self._bot_adapter(),
            chat_type="group",
            chat_id="group:grp-1",
            msg_body=[mention],
            from_account="alice",
            sender_nickname="Alice",
            raw_text="Hello",
            source=MagicMock(),
        )
        downstream = await _invoke(GroupAtGuardMiddleware, ctx)
        downstream.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_group_without_at_bot_observes(self):
        """GroupAtGuardMiddleware observes group messages without @bot."""
        adapter = self._bot_adapter()
        adapter._session_store = None  # No session store -> observe is a no-op
        ctx = make_ctx(
            adapter=adapter,
            chat_type="group",
            chat_id="group:grp-1",
            msg_body=[{"msg_type": "TIMTextElem", "msg_content": {"text": "hi"}}],
            from_account="alice",
            sender_nickname="Alice",
            raw_text="hi",
            source=MagicMock(),
        )
        downstream = await _invoke(GroupAtGuardMiddleware, ctx)
        downstream.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_owner_command_skips_at_check(self):
        """GroupAtGuardMiddleware passes when owner_command is set."""
        ctx = make_ctx(
            adapter=self._bot_adapter(),
            chat_type="group",
            msg_body=[],
            owner_command="/new",
            source=MagicMock(),
        )
        downstream = await _invoke(GroupAtGuardMiddleware, ctx)
        downstream.assert_awaited_once()
# ============================================================
# 4. Factory Tests
# ============================================================

class TestCreateInboundPipeline:
    """Tests for the default pipeline produced by InboundPipelineBuilder."""

    # Expected middleware names, in the exact order the default pipeline runs them.
    EXPECTED_ORDER = [
        "decode",
        "extract-fields",
        "dedup",
        "skip-self",
        "chat-routing",
        "access-guard",
        "extract-content",
        "placeholder-filter",
        "owner-command",
        "build-source",
        "group-at-guard",
        "classify-msg-type",
        "quote-context",
        "media-resolve",
        "dispatch",
    ]

    def test_default_pipeline_has_all_middlewares(self):
        """InboundPipelineBuilder.build() creates pipeline with all expected middlewares."""
        pipeline = InboundPipelineBuilder.build()
        # Previously `expected` was computed but never asserted, so this test
        # could not fail. Compare the full ordered list: order is part of the
        # contract (e.g. dedup must precede dispatch).
        assert pipeline.middleware_names == self.EXPECTED_ORDER

    def test_pipeline_is_customizable(self):
        """Pipeline can be customized after creation."""
        pipeline = InboundPipelineBuilder.build()

        async def custom_mw(ctx, next_fn):
            await next_fn()

        pipeline.use_before("dispatch", "custom", custom_mw)
        names = pipeline.middleware_names
        assert "custom" in names
        assert names.index("custom") < names.index("dispatch")
# ============================================================
# 5. End-to-End Pipeline Integration Tests
# ============================================================

class TestPipelineIntegration:
    """Run the default builder pipeline against JSON push frames."""

    @staticmethod
    def _hermetic_adapter(access_policy=None):
        """Return a test adapter whose dispatch and media hooks are mocked.

        ``handle_message`` and ``_resolve_inbound_media_urls`` are replaced by
        AsyncMocks, so a message that survives filtering never reaches the real
        handler or media resolution. ``access_policy`` is installed only when
        given; otherwise the adapter keeps the policy it was constructed with.
        """
        adapter = make_adapter()
        adapter._bot_id = "bot_123"
        if access_policy is not None:
            adapter._access_policy = access_policy
        adapter.handle_message = AsyncMock()
        adapter._resolve_inbound_media_urls = AsyncMock(return_value=([], []))
        return adapter

    @pytest.mark.asyncio
    async def test_full_dm_message_flow(self):
        """Full pipeline processes a DM message end-to-end."""
        adapter = self._hermetic_adapter(
            AccessPolicy(dm_policy="open", dm_allow_from=[], group_policy="open", group_allow_from=[])
        )

        push_data = make_json_push(
            from_account="alice",
            to_account="bot_123",
            text="Hello bot!",
            msg_id="msg-e2e-001",
        )

        ctx = InboundContext(adapter=adapter, raw_frames=[push_data])
        pipeline = InboundPipelineBuilder.build()
        await pipeline.execute(ctx)

        # Verify context was populated correctly
        assert ctx.decoded_via == "json"
        assert ctx.from_account == "alice"
        assert ctx.chat_type == "dm"
        assert ctx.chat_id == "direct:alice"
        assert "Hello bot!" in ctx.raw_text
        assert ctx.source is not None

    @pytest.mark.asyncio
    async def test_self_message_filtered(self):
        """Pipeline stops when message is from bot itself."""
        adapter = self._hermetic_adapter()

        push_data = make_json_push(
            from_account="bot_123",
            to_account="bot_123",
            text="echo",
            msg_id="msg-self-001",
        )

        ctx = InboundContext(adapter=adapter, raw_frames=[push_data])
        pipeline = InboundPipelineBuilder.build()
        await pipeline.execute(ctx)

        # Pipeline should have stopped at skip-self — no source built
        assert ctx.source is None

    @pytest.mark.asyncio
    async def test_duplicate_message_filtered(self):
        """Pipeline stops on duplicate message."""
        # Mocked hooks matter here: the first message may run all the way to dispatch.
        adapter = self._hermetic_adapter()

        # First message goes through
        push_data = make_json_push(
            from_account="alice",
            text="Hello!",
            msg_id="msg-dup-001",
        )
        ctx1 = InboundContext(adapter=adapter, raw_frames=[push_data])
        pipeline = InboundPipelineBuilder.build()
        await pipeline.execute(ctx1)
        assert ctx1.from_account == "alice"

        # Second message with same msg_id is filtered
        ctx2 = InboundContext(adapter=adapter, raw_frames=[push_data])
        await pipeline.execute(ctx2)
        # Dedup should stop pipeline before chat routing
        assert ctx2.chat_type == ""

    @pytest.mark.asyncio
    async def test_blocked_dm_filtered(self):
        """Pipeline stops when DM is blocked by policy."""
        adapter = self._hermetic_adapter(
            AccessPolicy(dm_policy="disabled", dm_allow_from=[], group_policy="open", group_allow_from=[])
        )

        push_data = make_json_push(
            from_account="alice",
            text="Hello!",
            msg_id="msg-blocked-001",
        )

        ctx = InboundContext(adapter=adapter, raw_frames=[push_data])
        pipeline = InboundPipelineBuilder.build()
        await pipeline.execute(ctx)

        # Pipeline stopped at access-guard — no content extracted
        assert ctx.raw_text == ""

    def test_adapter_has_pipeline(self):
        """YuanbaoAdapter.__init__ creates an inbound pipeline."""
        # Synchronous test: the asyncio marker was removed (it only applies to coroutines).
        adapter = make_adapter()
        assert hasattr(adapter, "_inbound_pipeline")
        assert isinstance(adapter._inbound_pipeline, InboundPipeline)


# ============================================================
# 6. OOP Middleware Tests
# ============================================================

class TestInboundMiddlewareABC:
    """Test the InboundMiddleware abstract base class."""

    def test_cannot_instantiate_abc(self):
        """InboundMiddleware cannot be instantiated directly."""
        with pytest.raises(TypeError):
            InboundMiddleware()

    def test_subclass_must_implement_handle(self):
        """Subclass without handle() raises TypeError."""
        with pytest.raises(TypeError):
            class BadMiddleware(InboundMiddleware):
                name = "bad"
            BadMiddleware()

    def test_subclass_with_handle_works(self):
        """Subclass with handle() can be instantiated."""
        class GoodMiddleware(InboundMiddleware):
            name = "good"

            async def handle(self, ctx, next_fn):
                await next_fn()

        mw = GoodMiddleware()
        assert mw.name == "good"

    @pytest.mark.asyncio
    async def test_callable_protocol(self):
        """Middleware instances are callable via __call__."""
        class TestMW(InboundMiddleware):
            name = "test"

            async def handle(self, ctx, next_fn):
                ctx.raw_text = "called"
                await next_fn()

        mw = TestMW()
        ctx = make_ctx()
        next_fn = AsyncMock()
        await mw(ctx, next_fn)  # Call via __call__
        assert ctx.raw_text == "called"
        next_fn.assert_awaited_once()

    def test_repr(self):
        """Middleware has a useful repr."""
        class MyMW(InboundMiddleware):
            name = "my-mw"

            async def handle(self, ctx, next_fn):
                pass

        mw = MyMW()
        assert "MyMW" in repr(mw)
        assert "my-mw" in repr(mw)


class TestMiddlewareClasses:
    """Test that all concrete middleware classes have correct names and are InboundMiddleware subclasses."""

    MIDDLEWARE_CLASSES = [
        (DecodeMiddleware, "decode"),
        (ExtractFieldsMiddleware, "extract-fields"),
        (DedupMiddleware, "dedup"),
        (SkipSelfMiddleware, "skip-self"),
        (ChatRoutingMiddleware, "chat-routing"),
        (AccessGuardMiddleware, "access-guard"),
        (ExtractContentMiddleware, "extract-content"),
        (PlaceholderFilterMiddleware, "placeholder-filter"),
        (OwnerCommandMiddleware, "owner-command"),
        (BuildSourceMiddleware, "build-source"),
        (GroupAtGuardMiddleware, "group-at-guard"),
        (DispatchMiddleware, "dispatch"),
    ]

    @pytest.mark.parametrize("cls,expected_name", MIDDLEWARE_CLASSES)
    def test_is_inbound_middleware(self, cls, expected_name):
        """Each middleware class is a subclass of InboundMiddleware."""
        assert issubclass(cls, InboundMiddleware)

    @pytest.mark.parametrize("cls,expected_name", MIDDLEWARE_CLASSES)
    def test_has_correct_name(self, cls, expected_name):
        """Each middleware class has the expected name."""
        mw = cls()
        assert mw.name == expected_name

    @pytest.mark.parametrize("cls,expected_name", MIDDLEWARE_CLASSES)
    def test_is_callable(self, cls, expected_name):
        """Each middleware instance is callable."""
        mw = cls()
        assert callable(mw)


class TestPipelineOOPRegistration:
    """Test that InboundPipeline works with OOP middleware instances."""

    @pytest.mark.asyncio
    async def test_use_with_middleware_instance(self):
        """pipeline.use(SomeMiddleware()) auto-extracts name."""
        class TestMW(InboundMiddleware):
            name = "test-mw"

            async def handle(self, ctx, next_fn):
                ctx.raw_text = "oop-works"
                await next_fn()

        pipeline = InboundPipeline().use(TestMW())
        assert pipeline.middleware_names == ["test-mw"]

        ctx = make_ctx()
        await pipeline.execute(ctx)
        assert ctx.raw_text == "oop-works"

    @pytest.mark.asyncio
    async def test_mixed_oop_and_functional(self):
        """Pipeline supports mixing OOP and functional middlewares."""
        order = []

        class OopMW(InboundMiddleware):
            name = "oop"

            async def handle(self, ctx, next_fn):
                order.append("oop")
                await next_fn()

        async def func_mw(ctx, next_fn):
            order.append("func")
            await next_fn()

        pipeline = (
            InboundPipeline()
            .use(OopMW())
            .use("func", func_mw)
        )
        assert pipeline.middleware_names == ["oop", "func"]

        await pipeline.execute(make_ctx())
        assert order == ["oop", "func"]

    def test_use_before_with_middleware_instance(self):
        """use_before works with OOP middleware instances."""
        class MwA(InboundMiddleware):
            name = "a"

            async def handle(self, ctx, next_fn):
                await next_fn()

        class MwB(InboundMiddleware):
            name = "b"

            async def handle(self, ctx, next_fn):
                await next_fn()

        class MwC(InboundMiddleware):
            name = "c"

            async def handle(self, ctx, next_fn):
                await next_fn()

        pipeline = InboundPipeline().use(MwA()).use(MwC())
        pipeline.use_before("c", MwB())
        assert pipeline.middleware_names == ["a", "b", "c"]

    def test_use_after_with_middleware_instance(self):
        """use_after works with OOP middleware instances."""
        class MwA(InboundMiddleware):
            name = "a"

            async def handle(self, ctx, next_fn):
                await next_fn()

        class MwB(InboundMiddleware):
            name = "b"

            async def handle(self, ctx, next_fn):
                await next_fn()

        class MwC(InboundMiddleware):
            name = "c"

            async def handle(self, ctx, next_fn):
                await next_fn()

        pipeline = InboundPipeline().use(MwA()).use(MwC())
        pipeline.use_after("a", MwB())
        assert pipeline.middleware_names == ["a", "b", "c"]


# Script entry point lives at the very end so every test class above is
# defined before pytest is invoked.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])