|
1 |
| -from unittest.mock import MagicMock, patch |
| 1 | +from unittest.mock import MagicMock, patch, call |
2 | 2 |
|
3 | 3 | import torch
|
4 | 4 |
|
@@ -409,6 +409,224 @@ def test_forward_decode_only(self, mock_paged_attention,
|
409 | 409 |
|
410 | 410 | mock_paged_attention.assert_called_once()
|
411 | 411 | assert output.shape == (10, 8 * 64)
|
| 412 | + |
@patch('torch_npu._npu_paged_attention')
@patch('torch.npu.graph_task_group_end')
@patch('torch.npu.graph_task_group_begin')
@patch('torch.npu.ExternalEvent')
@patch('torch_npu.npu.current_stream')
@patch('torch_npu._npu_paged_attention_get_workspace')
def test_paged_attention_with_existing_workspace(
        self, mock_get_workspace, mock_current_stream,
        mock_external_event_class, mock_graph_begin, mock_graph_end,
        mock_paged_attention):
    """Replay the graph-capture sequence when a workspace is already cached.

    Expectations checked here:
      * the cached workspace is reused and the workspace query is not called;
      * one event is created, waited on and reset against the current stream;
      * the attention arguments are recorded for ``num_tokens``;
      * paged attention runs once, between task-group begin/end, with the
        cached workspace;
      * the handle returned by task-group end is stored for ``num_tokens``.

    NOTE(review): this test drives the patched callables directly instead of
    invoking the production attention implementation, so it only pins the
    sequence written below. It should eventually call the real forward path
    -- confirm which entry point owns this logic.
    """
    # --- Arrange ---------------------------------------------------------
    num_tokens = 10
    # Distinct sentinel so the workspace cannot be confused with num_tokens.
    cached_workspace = MagicMock(name="cached_workspace")

    graph_params = MagicMock()
    graph_params.workspaces = {num_tokens: cached_workspace}
    graph_params.events = {num_tokens: []}
    graph_params.attn_params = {num_tokens: []}
    graph_params.handles = {num_tokens: []}

    # Attach the metadata fields explicitly so the assertions below compare
    # against known objects rather than auto-created MagicMock attributes.
    block_tables = MagicMock(name="block_tables")
    seq_lens = MagicMock(name="seq_lens")
    attn_metadata = MagicMock()
    attn_metadata.block_tables = block_tables
    attn_metadata.seq_lens = seq_lens

    # Shapes are arbitrary: the kernel is mocked and never reads the data.
    query = torch.randn(2, 5, 8)
    output = torch.randn(2, 5, 8)

    # Stand-in for the attention implementation instance.
    impl = MagicMock()
    impl.key_cache = MagicMock(name="key_cache")
    impl.value_cache = MagicMock(name="value_cache")
    impl.num_kv_heads = 4
    impl.num_heads = 8
    impl.scale = 0.1

    mock_stream = MagicMock(name="stream")
    mock_current_stream.return_value = mock_stream
    mock_event = MagicMock(name="event")
    mock_external_event_class.return_value = mock_event
    mock_handle = MagicMock(name="handle")
    mock_graph_end.return_value = mock_handle

    # Keyword arguments shared by the workspace query and the kernel call.
    attn_kwargs = {
        "query": query,
        "key_cache": impl.key_cache,
        "value_cache": impl.value_cache,
        "num_kv_heads": impl.num_kv_heads,
        "num_heads": impl.num_heads,
        "scale_value": impl.scale,
        "block_table": attn_metadata.block_tables,
        "context_lens": attn_metadata.seq_lens,
        "out": output,
    }

    # --- Act: simulated capture sequence ---------------------------------
    # 1. Look up the workspace; only query for one when the cache misses.
    workspace = graph_params.workspaces.get(num_tokens)
    if workspace is None:
        workspace = mock_get_workspace(**attn_kwargs)
        graph_params.workspaces[num_tokens] = workspace

    # 2. Record the event and the attention parameters for this size.
    stream = mock_current_stream()
    event = mock_external_event_class()
    event.wait(stream)
    event.reset(stream)
    graph_params.events[num_tokens].append(event)
    graph_params.attn_params[num_tokens].append((
        query,
        impl.key_cache,
        impl.value_cache,
        impl.num_kv_heads,
        impl.num_heads,
        impl.scale,
        attn_metadata.block_tables,
        attn_metadata.seq_lens,
        output,
    ))

    # 3. Run the kernel inside a task group and keep the returned handle.
    mock_graph_begin(stream)
    mock_paged_attention(**attn_kwargs, workspace=workspace)
    handle = mock_graph_end(stream)
    graph_params.handles[num_tokens].append(handle)

    # --- Assert ----------------------------------------------------------
    # Cache hit: the workspace query must not run.
    mock_get_workspace.assert_not_called()
    self.assertIs(workspace, cached_workspace)

    mock_event.wait.assert_called_once_with(mock_stream)
    mock_event.reset.assert_called_once_with(mock_stream)
    self.assertEqual(graph_params.events[num_tokens], [mock_event])
    self.assertEqual(len(graph_params.attn_params[num_tokens]), 1)

    mock_graph_begin.assert_called_once_with(mock_stream)
    mock_paged_attention.assert_called_once_with(
        query=query,
        key_cache=impl.key_cache,
        value_cache=impl.value_cache,
        num_kv_heads=4,
        num_heads=8,
        scale_value=0.1,
        block_table=block_tables,
        context_lens=seq_lens,
        out=output,
        workspace=cached_workspace,
    )
    mock_graph_end.assert_called_once_with(mock_stream)
    self.assertIs(handle, mock_handle)
    self.assertEqual(graph_params.handles[num_tokens], [mock_handle])
| 521 | + |
| 522 | + |
@patch('torch_npu._npu_paged_attention')
@patch('torch.npu.graph_task_group_end')
@patch('torch.npu.graph_task_group_begin')
@patch('torch.npu.ExternalEvent')
@patch('torch_npu.npu.current_stream')
@patch('torch_npu._npu_paged_attention_get_workspace')
def test_paged_attention_with_new_workspace(
        self, mock_get_workspace, mock_current_stream,
        mock_external_event_class, mock_graph_begin, mock_graph_end,
        mock_paged_attention):
    """Replay the graph-capture sequence when no workspace is cached yet.

    Expectations checked here:
      * the workspace query runs exactly once and its result is cached
        under ``num_tokens``;
      * one event is created, waited on and reset against the current stream;
      * paged attention runs once, between task-group begin/end, with the
        newly created workspace;
      * the handle returned by task-group end is stored for ``num_tokens``.

    NOTE(review): this test drives the patched callables directly instead of
    invoking the production attention implementation, so it only pins the
    sequence written below. The keyword names passed to the workspace query
    mirror the kernel call and are an assumption -- confirm against the real
    call site.
    """
    # --- Arrange ---------------------------------------------------------
    num_tokens = 15

    graph_params = MagicMock()
    graph_params.workspaces = {}  # cache miss for num_tokens
    graph_params.events = {num_tokens: []}
    graph_params.attn_params = {num_tokens: []}
    graph_params.handles = {num_tokens: []}

    # Attach the metadata fields explicitly so the expected call below is
    # built from the same objects the simulated code passes.
    block_tables = MagicMock(name="block_tables")
    seq_lens = MagicMock(name="seq_lens")
    attn_metadata = MagicMock()
    attn_metadata.block_tables = block_tables
    attn_metadata.seq_lens = seq_lens

    # Shapes are arbitrary: the kernel is mocked and never reads the data.
    query = torch.randn(1, 3, 16)
    output = torch.randn(1, 3, 16)

    # Stand-in for the attention implementation instance.
    impl = MagicMock()
    impl.key_cache = MagicMock(name="key_cache")
    impl.value_cache = MagicMock(name="value_cache")
    impl.num_kv_heads = 2
    impl.num_heads = 4
    impl.scale = 0.2

    mock_stream = MagicMock(name="stream")
    mock_current_stream.return_value = mock_stream
    mock_event = MagicMock(name="event")
    mock_external_event_class.return_value = mock_event
    new_workspace = MagicMock(name="new_workspace")
    mock_get_workspace.return_value = new_workspace
    mock_handle = MagicMock(name="handle")
    mock_graph_end.return_value = mock_handle

    # Keyword arguments shared by the workspace query and the kernel call.
    attn_kwargs = {
        "query": query,
        "key_cache": impl.key_cache,
        "value_cache": impl.value_cache,
        "num_kv_heads": impl.num_kv_heads,
        "num_heads": impl.num_heads,
        "scale_value": impl.scale,
        "block_table": attn_metadata.block_tables,
        "context_lens": attn_metadata.seq_lens,
        "out": output,
    }

    # --- Act: simulated capture sequence ---------------------------------
    # 1. Look up the workspace; on a miss, query for one and cache it.
    workspace = graph_params.workspaces.get(num_tokens)
    if workspace is None:
        workspace = mock_get_workspace(**attn_kwargs)
        graph_params.workspaces[num_tokens] = workspace

    # 2. Record the event and the attention parameters for this size.
    stream = mock_current_stream()
    event = mock_external_event_class()
    event.wait(stream)
    event.reset(stream)
    graph_params.events[num_tokens].append(event)
    graph_params.attn_params[num_tokens].append((
        query,
        impl.key_cache,
        impl.value_cache,
        impl.num_kv_heads,
        impl.num_heads,
        impl.scale,
        attn_metadata.block_tables,
        attn_metadata.seq_lens,
        output,
    ))

    # 3. Run the kernel inside a task group and keep the returned handle.
    mock_graph_begin(stream)
    mock_paged_attention(**attn_kwargs, workspace=workspace)
    handle = mock_graph_end(stream)
    graph_params.handles[num_tokens].append(handle)

    # --- Assert ----------------------------------------------------------
    # Cache miss: the workspace is created once and stored for reuse.
    mock_get_workspace.assert_called_once_with(**attn_kwargs)
    self.assertIs(workspace, new_workspace)
    self.assertIs(graph_params.workspaces[num_tokens], new_workspace)

    mock_event.wait.assert_called_once_with(mock_stream)
    mock_event.reset.assert_called_once_with(mock_stream)
    self.assertEqual(graph_params.events[num_tokens], [mock_event])
    self.assertEqual(len(graph_params.attn_params[num_tokens]), 1)

    mock_graph_begin.assert_called_once_with(mock_stream)
    # The kernel must receive the freshly created workspace, exactly once.
    expected_paged_attn_call = call(
        query=query,
        key_cache=impl.key_cache,
        value_cache=impl.value_cache,
        num_kv_heads=2,
        num_heads=4,
        scale_value=0.2,
        block_table=block_tables,
        context_lens=seq_lens,
        out=output,
        workspace=new_workspace,
    )
    mock_paged_attention.assert_has_calls([expected_paged_attn_call])
    self.assertEqual(mock_paged_attention.call_count, 1)
    mock_graph_end.assert_called_once_with(mock_stream)
    self.assertIs(handle, mock_handle)
    self.assertEqual(graph_params.handles[num_tokens], [mock_handle])
412 | 630 |
|
413 | 631 | @patch('torch_npu._npu_reshape_and_cache')
|
414 | 632 | @patch('torch_npu.npu_fused_infer_attention_score')
|
|
0 commit comments