# Mirror of https://github.com/coleam00/context-engineering-intro.git
# Synced 2025-12-17 17:55:29 +00:00
# 335 lines, 13 KiB, Python
"""Test core agent functionality."""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, patch
|
|
from pydantic_ai.models.test import TestModel
|
|
from pydantic_ai.models.function import FunctionModel
|
|
from pydantic_ai.messages import ModelTextResponse
|
|
|
|
from ..agent import search_agent, search, SearchResponse, interactive_search
|
|
from ..dependencies import AgentDependencies
|
|
|
|
|
|
class TestAgentInitialization:
    """Checks on how the agent under test is configured."""

    def test_agent_has_correct_model_type(self, test_agent):
        """The agent exposes a model, and that model is a ``TestModel``."""
        model = test_agent.model
        assert model is not None
        assert isinstance(model, TestModel)

    def test_agent_has_dependencies_type(self, test_agent):
        """The agent's ``deps_type`` equals ``AgentDependencies``."""
        assert test_agent.deps_type == AgentDependencies

    def test_agent_has_system_prompt(self, test_agent):
        """The system prompt is present, non-empty and mentions semantic search."""
        prompt = test_agent.system_prompt
        assert prompt is not None
        assert len(prompt) > 0
        assert "semantic search" in prompt.lower()

    def test_agent_has_registered_tools(self, test_agent):
        """Each required tool name appears among the agent's tool definitions."""
        registered = [tool_def.name for tool_def in test_agent.tool_defs]

        # Checked one by one so the failure message names the missing tool.
        for expected_tool in (
            'semantic_search',
            'hybrid_search',
            'auto_search',
            'set_search_preference',
        ):
            assert expected_tool in registered, f"Missing tool: {expected_tool}"
|
|
|
|
|
|
class TestAgentBasicFunctionality:
    """Smoke tests: run the agent on simple inputs and inspect the result."""

    @pytest.mark.asyncio
    async def test_agent_responds_to_simple_query(self, test_agent, test_dependencies):
        """A plain query yields a string result and a non-empty message log."""
        deps, _ = test_dependencies

        result = await test_agent.run("Search for Python tutorials", deps=deps)

        output = result.data
        assert output is not None
        assert isinstance(output, str)
        assert len(result.all_messages()) > 0

    @pytest.mark.asyncio
    async def test_agent_with_empty_query(self, test_agent, test_dependencies):
        """An empty prompt still produces a string result."""
        deps, _ = test_dependencies

        result = await test_agent.run("", deps=deps)

        # Even with nothing to search for, some response is expected.
        output = result.data
        assert output is not None
        assert isinstance(output, str)

    @pytest.mark.asyncio
    async def test_agent_with_long_query(self, test_agent, test_dependencies):
        """A long prompt still produces a string result."""
        deps, _ = test_dependencies

        # 26-character phrase repeated 50 times: 1300 characters in total.
        long_query = "This is a very long query " * 50
        result = await test_agent.run(long_query, deps=deps)

        output = result.data
        assert output is not None
        assert isinstance(output, str)
|
|
|
|
|
|
class TestAgentToolCalling:
    """Test agent tool calling behavior.

    Both tests drive ``search_agent`` with a scripted ``FunctionModel`` whose
    reply depends on how many times the model has been invoked so far.
    """

    @pytest.mark.asyncio
    async def test_agent_calls_search_tools(self, test_dependencies, mock_database_responses):
        """Test agent calls appropriate search tools.

        Passes when the message history of the run contains at least one
        entry whose ``tool_name`` is ``auto_search``.
        """
        deps, connection = test_dependencies

        # Configure mock database responses
        connection.fetch.return_value = mock_database_responses['semantic_search']

        # Create function model that calls tools
        call_count = 0

        async def search_function(messages, tools):
            nonlocal call_count
            call_count += 1

            # NOTE(review): the first scripted reply is plain text. Confirm the
            # agent keeps going after a text-only reply; if it treats that as
            # the final answer, the tool call below is never reached.
            if call_count == 1:
                return ModelTextResponse(content="I'll search for that information.")
            elif call_count == 2:
                # NOTE(review): presumably a dict is an accepted way to request
                # a tool call from a FunctionModel callback -- verify against
                # the installed pydantic_ai version.
                return {"auto_search": {"query": "test query", "match_count": 10}}
            else:
                return ModelTextResponse(content="Based on the search results, here's what I found...")

        function_model = FunctionModel(search_function)

        # Agent.override() is a context manager, not a factory for a new agent:
        # it swaps the model on search_agent itself while the block is active.
        with search_agent.override(model=function_model):
            result = await search_agent.run("Search for Python tutorials", deps=deps)

        # Verify tool was called
        messages = result.all_messages()
        tool_calls = [msg for msg in messages if hasattr(msg, 'tool_name')]
        assert len(tool_calls) > 0, "No tool calls found"

        # Verify auto_search was called
        auto_search_calls = [msg for msg in tool_calls if getattr(msg, 'tool_name', None) == 'auto_search']
        assert len(auto_search_calls) > 0, "auto_search tool was not called"

    @pytest.mark.asyncio
    async def test_agent_calls_preference_tool(self, test_dependencies):
        """Test agent calls preference setting tool.

        Passes when, after the run, ``deps.user_preferences`` maps
        ``search_type`` to ``semantic`` and the run produced a result.
        """
        deps, connection = test_dependencies

        call_count = 0

        async def preference_function(messages, tools):
            nonlocal call_count
            call_count += 1

            if call_count == 1:
                # First invocation asks for the preference tool; see the
                # review note in test_agent_calls_search_tools about this
                # dict-shaped return value.
                return {"set_search_preference": {"preference_type": "search_type", "value": "semantic"}}
            else:
                return ModelTextResponse(content="Preference set successfully.")

        function_model = FunctionModel(preference_function)

        # override() must be entered as a context manager; the run happens on
        # search_agent while the scripted model is in effect.
        with search_agent.override(model=function_model):
            result = await search_agent.run("Set search preference to semantic", deps=deps)

        # Verify preference was set
        assert deps.user_preferences.get('search_type') == 'semantic'
        assert result.data is not None
|
|
|
|
|
|
class TestSearchFunction:
    """Test the standalone search function.

    Each test replaces the ``search_agent`` global of the module that defines
    ``search`` with a mock, so no real model is contacted.
    """

    @staticmethod
    def _patch_search_agent():
        """Return a patcher for ``search_agent`` in the module defining ``search``.

        ``unittest.mock.patch`` needs an absolute dotted path; a relative
        target such as ``'..agent.search_agent'`` cannot be imported. The
        absolute module name is taken from ``search.__module__``.
        """
        return patch(f"{search.__module__}.search_agent")

    @staticmethod
    def _install_run_result(mock_agent, summary):
        """Make awaiting ``mock_agent.run(...)`` yield a result with ``.data == summary``.

        NOTE(review): assumes ``search()`` awaits ``search_agent.run`` -- confirm
        against the agent module. ``run`` is an ``AsyncMock`` so that the call
        returns an awaitable.
        """
        mock_result = AsyncMock()
        mock_result.data = summary
        mock_agent.run = AsyncMock(return_value=mock_result)

    @pytest.mark.asyncio
    async def test_search_function_with_defaults(self):
        """Test search function with default parameters."""
        with self._patch_search_agent() as mock_agent:
            # Mock agent run result
            self._install_run_result(mock_agent, "Search results found")

            response = await search("test query")

            assert isinstance(response, SearchResponse)
            assert response.summary == "Search results found"
            assert response.search_strategy == "auto"
            assert response.result_count == 10

    @pytest.mark.asyncio
    async def test_search_function_with_custom_params(self):
        """Test search function with custom parameters."""
        with self._patch_search_agent() as mock_agent:
            self._install_run_result(mock_agent, "Custom search results")

            response = await search(
                query="custom query",
                search_type="semantic",
                match_count=20,
                text_weight=0.5,
            )

            assert isinstance(response, SearchResponse)
            assert response.summary == "Custom search results"
            assert response.result_count == 20

    @pytest.mark.asyncio
    async def test_search_function_with_existing_deps(self, test_dependencies):
        """Test search function with provided dependencies."""
        deps, connection = test_dependencies

        with self._patch_search_agent() as mock_agent:
            self._install_run_result(mock_agent, "Search with deps")

            response = await search("test query", deps=deps)

            assert isinstance(response, SearchResponse)
            assert response.summary == "Search with deps"
            # Should not call cleanup since deps were provided
            assert deps.db_pool is not None
|
|
|
|
|
|
class TestInteractiveSearch:
    """Covers how ``interactive_search`` obtains its dependencies."""

    @pytest.mark.asyncio
    async def test_interactive_search_creates_deps(self):
        """Called without arguments, it returns initialized ``AgentDependencies``."""
        # initialize() is patched out on the class for the duration of the call.
        with patch.object(AgentDependencies, 'initialize') as mock_init:
            created = await interactive_search()

            assert isinstance(created, AgentDependencies)
            assert created.session_id is not None
            mock_init.assert_called_once()

    @pytest.mark.asyncio
    async def test_interactive_search_reuses_deps(self, test_dependencies):
        """Called with existing dependencies, it hands back the same object."""
        existing_deps, _ = test_dependencies

        returned = await interactive_search(existing_deps)

        assert returned is existing_deps
        assert returned.session_id == "test_session"
|
|
|
|
|
|
class TestAgentErrorHandling:
    """Failure scenarios: the agent is expected to answer rather than raise."""

    @pytest.mark.asyncio
    async def test_agent_handles_database_error(self, test_agent, test_dependencies):
        """A failing ``connection.fetch`` must not propagate out of the run."""
        deps, db_connection = test_dependencies

        # Every fetch on the mocked connection now raises.
        db_connection.fetch.side_effect = Exception("Database connection failed")

        # The run is awaited without a try/except: an exception fails the test.
        result = await test_agent.run("Search for something", deps=deps)

        # Some string response is expected even though the search failed.
        output = result.data
        assert output is not None
        assert isinstance(output, str)

    @pytest.mark.asyncio
    async def test_agent_handles_invalid_dependencies(self, test_agent):
        """Dependencies built with no arguments still lead to a string result."""
        # Constructed directly, with no further setup performed on it.
        bare_deps = AgentDependencies()

        # The run is expected to cope with the missing database pool.
        result = await test_agent.run("Search query", deps=bare_deps)

        output = result.data
        assert output is not None
        assert isinstance(output, str)
|
|
|
|
|
|
class TestAgentResponseQuality:
    """Coarse checks on the content and size of agent responses."""

    @pytest.mark.asyncio
    async def test_agent_response_mentions_search(self, test_agent, test_dependencies):
        """The response contains at least one search-related word."""
        deps, _ = test_dependencies

        result = await test_agent.run("Find information about machine learning", deps=deps)

        lowered = result.data.lower()

        # One matching term is enough; the comparison is case-insensitive.
        assert any(
            term in lowered
            for term in ('search', 'find', 'information', 'results')
        )

    @pytest.mark.asyncio
    async def test_agent_response_reasonable_length(self, test_agent, test_dependencies):
        """The response length lies between 10 and 2000 characters inclusive."""
        deps, _ = test_dependencies

        result = await test_agent.run("What is Python?", deps=deps)

        # Neither trivially short nor excessively long.
        length = len(result.data)
        assert 10 <= length <= 2000

    @pytest.mark.asyncio
    async def test_agent_handles_different_query_types(self, test_agent, test_dependencies):
        """Several kinds of query each produce a non-empty string."""
        deps, _ = test_dependencies

        sample_queries = (
            "What is Python?",  # Conceptual
            "Find exact quote about 'machine learning'",  # Exact match
            "Show me tutorials",  # General
            "API documentation for requests library",  # Technical
        )

        # Queries run sequentially against the same dependencies.
        for prompt in sample_queries:
            result = await test_agent.run(prompt, deps=deps)

            output = result.data
            assert output is not None
            assert isinstance(output, str)
            assert len(output) > 0
|
|
|
|
|
|
class TestAgentMemoryAndContext:
    """Test agent memory and context handling."""

    @pytest.mark.asyncio
    async def test_agent_maintains_session_context(self, test_dependencies):
        """Test agent can maintain session context.

        Stores a preference and a history entry on the dependencies, runs the
        agent once, then checks that both are still present afterwards.
        """
        deps, connection = test_dependencies

        # Set some preferences
        deps.set_user_preference('search_type', 'semantic')
        deps.add_to_history('previous query')

        # Agent.override() is a context manager, not a factory for a new agent:
        # the TestModel is in effect on search_agent only inside this block.
        with search_agent.override(model=TestModel()):
            result = await search_agent.run("Another query", deps=deps)

        # Verify context is maintained
        assert deps.user_preferences['search_type'] == 'semantic'
        assert 'previous query' in deps.query_history
        assert result.data is not None

    @pytest.mark.asyncio
    async def test_agent_query_history_limit(self, test_dependencies):
        """Test query history is properly limited.

        After 15 additions the history is expected to hold exactly the last
        10 entries, ``query 5`` through ``query 14``.
        """
        deps, connection = test_dependencies

        # Add more than 10 queries
        for i in range(15):
            deps.add_to_history(f"query {i}")

        # Should only keep last 10
        assert len(deps.query_history) == 10
        assert deps.query_history[0] == "query 5"
        assert deps.query_history[-1] == "query 14"