2025-08-22 21:01:17 -05:00

335 lines
13 KiB
Python

"""Test core agent functionality."""
import pytest
from unittest.mock import AsyncMock, patch
from pydantic_ai.models.test import TestModel
from pydantic_ai.models.function import FunctionModel
from pydantic_ai.messages import ModelTextResponse
from ..agent import search_agent, search, SearchResponse, interactive_search
from ..dependencies import AgentDependencies
class TestAgentInitialization:
    """Verify how the search agent is constructed and configured."""

    def test_agent_has_correct_model_type(self, test_agent):
        """The fixture agent should carry a TestModel instance as its model."""
        configured_model = test_agent.model
        assert configured_model is not None
        assert isinstance(configured_model, TestModel)

    def test_agent_has_dependencies_type(self, test_agent):
        """The agent should declare AgentDependencies as its deps type."""
        assert test_agent.deps_type == AgentDependencies

    def test_agent_has_system_prompt(self, test_agent):
        """The agent should expose a non-empty, search-oriented system prompt."""
        # NOTE(review): on a pydantic_ai ``Agent``, ``system_prompt`` is the
        # decorator used to register dynamic prompts rather than the prompt
        # text, in which case ``len()`` below would raise TypeError — confirm
        # what the ``test_agent`` fixture actually returns.
        assert test_agent.system_prompt is not None
        assert len(test_agent.system_prompt) > 0
        assert "semantic search" in test_agent.system_prompt.lower()

    def test_agent_has_registered_tools(self, test_agent):
        """Every search tool the agent relies on must be registered."""
        # NOTE(review): ``tool_defs`` does not look like a public pydantic_ai
        # ``Agent`` attribute — confirm the fixture provides it.
        registered = [tool_def.name for tool_def in test_agent.tool_defs]
        required_tools = (
            'semantic_search',
            'hybrid_search',
            'auto_search',
            'set_search_preference',
        )
        for required in required_tools:
            assert required in registered, f"Missing tool: {required}"
class TestAgentBasicFunctionality:
    """Test basic agent functionality with TestModel."""

    @pytest.mark.asyncio
    async def test_agent_responds_to_simple_query(self, test_agent, test_dependencies):
        """A plain query yields a string result and a non-empty message log."""
        deps, _connection = test_dependencies
        result = await test_agent.run(
            "Search for Python tutorials",
            deps=deps,
        )
        assert result.data is not None
        assert isinstance(result.data, str)
        assert len(result.all_messages()) > 0

    @pytest.mark.asyncio
    async def test_agent_with_empty_query(self, test_agent, test_dependencies):
        """An empty prompt must still produce a string response."""
        deps, _connection = test_dependencies
        result = await test_agent.run("", deps=deps)
        # Should still provide a response rather than erroring out.
        assert result.data is not None
        assert isinstance(result.data, str)

    @pytest.mark.asyncio
    async def test_agent_with_long_query(self, test_agent, test_dependencies):
        """A long prompt must still produce a string response."""
        deps, _connection = test_dependencies
        # 26 characters repeated 50 times: a 1300-character prompt.
        long_query = "This is a very long query " * 50
        result = await test_agent.run(long_query, deps=deps)
        assert result.data is not None
        assert isinstance(result.data, str)
class TestAgentToolCalling:
    """Test that the agent dispatches to its tools when the model requests them.

    Both tests drive ``search_agent`` with a ``FunctionModel`` whose callback
    scripts the model's turns: first a tool call, then a final text answer.
    ``Agent.override`` is a context manager that swaps the model only inside
    its ``with`` block; it does not return a new agent, so ``run`` is called
    on ``search_agent`` itself inside that block.
    """

    @staticmethod
    def _tool_call(tool_name, args):
        """Build a model response that requests one call to ``tool_name``.

        A FunctionModel callback must return a model-response object; a plain
        dict is not understood by the agent, so tool calls are wrapped here.
        """
        # Imported locally so only these tests depend on the structured
        # message classes (same pydantic_ai messages API as ModelTextResponse).
        from pydantic_ai.messages import ModelStructuredResponse, ToolCall

        return ModelStructuredResponse(calls=[ToolCall.from_dict(tool_name, args)])

    @pytest.mark.asyncio
    async def test_agent_calls_search_tools(self, test_dependencies, mock_database_responses):
        """The agent executes ``auto_search`` when the model asks for it."""
        deps, connection = test_dependencies
        # Configure mock database responses returned to the search tool.
        connection.fetch.return_value = mock_database_responses['semantic_search']

        call_count = 0

        async def search_function(messages, tools):
            nonlocal call_count
            call_count += 1
            # The tool call must come first: a text response is a final
            # result for a str agent and would end the run before any tool ran.
            if call_count == 1:
                return self._tool_call(
                    "auto_search", {"query": "test query", "match_count": 10}
                )
            return ModelTextResponse(
                content="Based on the search results, here's what I found..."
            )

        with search_agent.override(model=FunctionModel(search_function)):
            result = await search_agent.run("Search for Python tutorials", deps=deps)

        # Messages carrying ``tool_name`` are the tool round-trip records.
        messages = result.all_messages()
        tool_calls = [msg for msg in messages if hasattr(msg, 'tool_name')]
        assert len(tool_calls) > 0, "No tool calls found"

        auto_search_calls = [
            msg for msg in tool_calls if getattr(msg, 'tool_name', None) == 'auto_search'
        ]
        assert len(auto_search_calls) > 0, "auto_search tool was not called"

    @pytest.mark.asyncio
    async def test_agent_calls_preference_tool(self, test_dependencies):
        """The ``set_search_preference`` tool writes into ``deps.user_preferences``."""
        deps, _connection = test_dependencies
        call_count = 0

        async def preference_function(messages, tools):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return self._tool_call(
                    "set_search_preference",
                    {"preference_type": "search_type", "value": "semantic"},
                )
            return ModelTextResponse(content="Preference set successfully.")

        with search_agent.override(model=FunctionModel(preference_function)):
            result = await search_agent.run("Set search preference to semantic", deps=deps)

        # Verify the tool actually mutated the dependencies object.
        assert deps.user_preferences.get('search_type') == 'semantic'
        assert result.data is not None
class TestSearchFunction:
    """Test the standalone ``search`` convenience function."""

    @staticmethod
    def _patch_search_agent(summary):
        """Return a patcher that replaces ``search_agent`` in the agent module.

        ``patch()`` resolves string targets by absolute import path, so a
        relative target such as ``'..agent.search_agent'`` cannot be resolved.
        Patching the imported module object works wherever the package lives.
        ``Agent.run`` is awaited, so the replacement ``run`` is an AsyncMock
        resolving to a result whose ``data`` is ``summary``.
        """
        from unittest.mock import MagicMock

        from .. import agent as agent_module

        mock_result = MagicMock()
        mock_result.data = summary
        return patch.object(
            agent_module,
            'search_agent',
            run=AsyncMock(return_value=mock_result),
        )

    @pytest.mark.asyncio
    async def test_search_function_with_defaults(self):
        """Test search function with default parameters."""
        # NOTE(review): without ``deps``, ``search`` may build and initialize
        # real AgentDependencies; consider also patching
        # ``AgentDependencies.initialize`` as TestInteractiveSearch does.
        with self._patch_search_agent("Search results found"):
            response = await search("test query")

        assert isinstance(response, SearchResponse)
        assert response.summary == "Search results found"
        assert response.search_strategy == "auto"
        assert response.result_count == 10

    @pytest.mark.asyncio
    async def test_search_function_with_custom_params(self):
        """Test search function with custom parameters."""
        with self._patch_search_agent("Custom search results"):
            response = await search(
                query="custom query",
                search_type="semantic",
                match_count=20,
                text_weight=0.5,
            )

        assert isinstance(response, SearchResponse)
        assert response.summary == "Custom search results"
        assert response.result_count == 20

    @pytest.mark.asyncio
    async def test_search_function_with_existing_deps(self, test_dependencies):
        """Test search function with provided dependencies."""
        deps, _connection = test_dependencies
        with self._patch_search_agent("Search with deps"):
            response = await search("test query", deps=deps)

        assert isinstance(response, SearchResponse)
        assert response.summary == "Search with deps"
        # Caller-supplied deps should be left usable: the pool is still set.
        assert deps.db_pool is not None
class TestInteractiveSearch:
    """Exercise how ``interactive_search`` sets up a session."""

    @pytest.mark.asyncio
    async def test_interactive_search_creates_deps(self):
        """With no argument, a new AgentDependencies is built and initialized."""
        with patch.object(AgentDependencies, 'initialize') as initialize_mock:
            fresh_deps = await interactive_search()
            assert isinstance(fresh_deps, AgentDependencies)
            assert fresh_deps.session_id is not None
            initialize_mock.assert_called_once()

    @pytest.mark.asyncio
    async def test_interactive_search_reuses_deps(self, test_dependencies):
        """A dependencies object passed in is returned as the same object."""
        provided_deps, _ = test_dependencies
        returned_deps = await interactive_search(provided_deps)
        assert returned_deps is provided_deps
        assert returned_deps.session_id == "test_session"
class TestAgentErrorHandling:
    """Check that the agent still answers when its environment is broken."""

    @pytest.mark.asyncio
    async def test_agent_handles_database_error(self, test_agent, test_dependencies):
        """A failing database fetch must not escape from ``run``."""
        deps, connection = test_dependencies
        # Make every fetch on the mocked connection blow up.
        connection.fetch.side_effect = Exception("Database connection failed")

        outcome = await test_agent.run("Search for something", deps=deps)

        # A textual reply is still expected even though the search failed.
        assert outcome.data is not None
        assert isinstance(outcome.data, str)

    @pytest.mark.asyncio
    async def test_agent_handles_invalid_dependencies(self, test_agent):
        """Uninitialized dependencies (no database pool) still yield text."""
        bare_deps = AgentDependencies()

        outcome = await test_agent.run("Search query", deps=bare_deps)

        assert outcome.data is not None
        assert isinstance(outcome.data, str)
class TestAgentResponseQuality:
    """Sanity checks on the wording and size of agent replies."""

    @pytest.mark.asyncio
    async def test_agent_response_mentions_search(self, test_agent, test_dependencies):
        """The reply should contain at least one search-related word."""
        deps, _ = test_dependencies
        result = await test_agent.run("Find information about machine learning", deps=deps)
        lowered_reply = result.data.lower()
        assert any(
            word in lowered_reply
            for word in ('search', 'find', 'information', 'results')
        )

    @pytest.mark.asyncio
    async def test_agent_response_reasonable_length(self, test_agent, test_dependencies):
        """The reply should be substantial but not excessive."""
        deps, _ = test_dependencies
        result = await test_agent.run("What is Python?", deps=deps)
        reply_length = len(result.data)
        assert 10 <= reply_length <= 2000

    @pytest.mark.asyncio
    async def test_agent_handles_different_query_types(self, test_agent, test_dependencies):
        """Each style of query must produce a non-empty string reply."""
        deps, _ = test_dependencies
        sample_queries = (
            "What is Python?",  # conceptual
            "Find exact quote about 'machine learning'",  # exact match
            "Show me tutorials",  # general
            "API documentation for requests library",  # technical
        )
        for prompt in sample_queries:
            reply = (await test_agent.run(prompt, deps=deps)).data
            assert reply is not None
            assert isinstance(reply, str)
            assert len(reply) > 0
class TestAgentMemoryAndContext:
    """Test agent memory and context handling."""

    @pytest.mark.asyncio
    async def test_agent_maintains_session_context(self, test_dependencies):
        """Preferences and history stored on deps survive an agent run."""
        deps, _connection = test_dependencies
        # Seed some session state before running the agent.
        deps.set_user_preference('search_type', 'semantic')
        deps.add_to_history('previous query')

        # ``Agent.override`` is a context manager that swaps the model only
        # inside the ``with`` block; it does not return a new agent.
        with search_agent.override(model=TestModel()):
            result = await search_agent.run("Another query", deps=deps)

        # Verify the seeded context is still present after the run.
        assert deps.user_preferences['search_type'] == 'semantic'
        assert 'previous query' in deps.query_history
        assert result.data is not None

    @pytest.mark.asyncio
    async def test_agent_query_history_limit(self, test_dependencies):
        """Query history is expected to keep only the 10 most recent entries."""
        deps, _connection = test_dependencies
        # Add more than 10 queries.
        for i in range(15):
            deps.add_to_history(f"query {i}")

        # Only the last 10 (query 5 .. query 14) should remain, oldest first.
        assert len(deps.query_history) == 10
        assert deps.query_history[0] == "query 5"
        assert deps.query_history[-1] == "query 14"