Mirror of https://github.com/crewAIInc/crewAI.git
Synced 2026-01-11 00:58:30 +00:00
254 lines, 9.8 KiB, Python
"""
|
|
Tests for the Ollama monkey patch utility.
|
|
"""
|
|
|
|
import unittest
|
|
from unittest.mock import patch, MagicMock, call
|
|
import json
|
|
from types import SimpleNamespace
|
|
import pytest
|
|
|
|
from crewai.utilities.ollama.monkey_patch import (
|
|
apply_monkey_patch,
|
|
query_ollama,
|
|
extract_prompt_from_messages
|
|
)
|
|
|
|
|
|
class TestOllamaMonkeyPatch(unittest.TestCase):
    """Test cases for the Ollama monkey patch utility.

    The ``query_ollama`` tests patch ``requests.post`` so no Ollama server is
    contacted. The ``apply_monkey_patch`` tests patch ``litellm.completion``;
    ``unittest.mock.patch`` restores the patched attribute when each test
    exits, so the ``litellm.completion`` replacement installed by
    ``apply_monkey_patch`` is reset between tests.
    """

    # ------------------------------------------------------------------
    # Helpers (names must not start with ``test`` so the loader skips them)
    # ------------------------------------------------------------------

    @staticmethod
    def _json_response(payload):
        """Build a fake non-streaming response whose ``json()`` returns *payload*."""
        response = MagicMock()
        response.json.return_value = payload
        return response

    @staticmethod
    def _ndjson_response(chunks):
        """Build a fake streaming response whose ``iter_lines()`` yields one
        JSON-encoded ``bytes`` line per dict in *chunks*."""
        response = MagicMock()
        response.iter_lines.return_value = [json.dumps(chunk).encode() for chunk in chunks]
        return response

    @staticmethod
    def _patched_completion():
        """Apply the monkey patch and return the resulting ``litellm.completion``."""
        apply_monkey_patch()
        import litellm
        return litellm.completion

    def _assert_generate_request(self, mock_post, *, stream, temperature):
        """Assert *mock_post* received exactly one ``/api/generate`` request.

        Checks the URL, model and prompt used by the ``query_ollama`` tests in
        this class, plus the given ``temperature`` and ``stream`` values.

        Args:
            mock_post: The mock standing in for ``requests.post``.
            stream: Expected value of the ``stream`` flag in the payload.
            temperature: Expected ``temperature`` option in the payload.
        """
        mock_post.assert_called_once()
        args, kwargs = mock_post.call_args
        payload = kwargs["json"]
        self.assertEqual(args[0], "http://localhost:11434/api/generate")
        self.assertEqual(payload["model"], "llama3")
        self.assertEqual(payload["prompt"], "Test prompt")
        self.assertEqual(payload["options"]["temperature"], temperature)
        # NOTE(review): Ollama's REST API documents ``stream`` as a top-level
        # request field rather than an entry in ``options``. This assertion
        # expects it inside ``options`` -- confirm query_ollama sends it where
        # the server actually reads it.
        self.assertEqual(payload["options"]["stream"], stream)

    # ------------------------------------------------------------------
    # extract_prompt_from_messages
    # ------------------------------------------------------------------

    def test_extract_prompt_from_messages(self):
        """Test extracting a prompt from a list of messages."""
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"},
            {"role": "assistant", "content": "I'm doing well, thank you!"},
            {"role": "user", "content": "Tell me about CrewAI."},
        ]

        prompt = extract_prompt_from_messages(messages)

        # Every role label and every message body must appear in the prompt.
        for expected in (
            "System:",
            "You are a helpful assistant.",
            "User:",
            "Hello, how are you?",
            "Assistant:",
            "I'm doing well, thank you!",
            "Tell me about CrewAI.",
        ):
            with self.subTest(expected=expected):
                self.assertIn(expected, prompt)

    # ------------------------------------------------------------------
    # query_ollama
    # ------------------------------------------------------------------

    @patch('requests.post')
    def test_query_ollama_non_streaming(self, mock_post):
        """Test querying Ollama API in non-streaming mode."""
        mock_post.return_value = self._json_response({"response": "This is a test response."})

        result = query_ollama(
            prompt="Test prompt",
            model="llama3",
            base_url="http://localhost:11434",
            stream=False,
            temperature=0.5,
        )

        self.assertEqual(result, "This is a test response.")
        self._assert_generate_request(mock_post, stream=False, temperature=0.5)

    @patch('requests.post')
    def test_query_ollama_streaming(self, mock_post):
        """Test querying Ollama API in streaming mode."""
        mock_post.return_value = self._ndjson_response([
            {"response": "This"},
            {"response": " is"},
            {"response": " a"},
            {"response": " test"},
            {"response": " response.", "done": True},
        ])

        result = query_ollama(
            prompt="Test prompt",
            model="llama3",
            base_url="http://localhost:11434",
            stream=True,
            temperature=0.5,
        )

        # Streaming mode hands back an iterator of text fragments.
        self.assertTrue(hasattr(result, '__next__'))
        self.assertEqual(list(result), ["This", " is", " a", " test", " response."])

        # The request is checked only after the iterator is consumed: if
        # query_ollama is lazy, requests.post may not run before the first
        # item is pulled.
        self._assert_generate_request(mock_post, stream=True, temperature=0.5)

    @patch('requests.post')
    def test_query_ollama_with_stop_sequences(self, mock_post):
        """Test querying Ollama API with stop sequences."""
        mock_post.return_value = self._json_response({"response": "This is a test response."})

        query_ollama(
            prompt="Test prompt",
            model="llama3",
            stop=["END", "STOP"],
        )

        # The stop sequences must be forwarded in the request options.
        mock_post.assert_called_once()
        _, kwargs = mock_post.call_args
        self.assertEqual(kwargs["json"]["options"]["stop"], ["END", "STOP"])

    @patch('requests.post')
    def test_query_ollama_error_handling(self, mock_post):
        """Test error handling in query_ollama."""
        mock_post.side_effect = Exception("Test error")

        result = query_ollama(prompt="Test prompt")

        # The failure is expected in the return value, not raised.
        self.assertIn("Error:", result)
        self.assertIn("Test error", result)

    # ------------------------------------------------------------------
    # apply_monkey_patch / patched litellm.completion
    # ------------------------------------------------------------------

    @patch('litellm.completion')
    def test_apply_monkey_patch(self, mock_completion):
        """Test applying the monkey patch."""
        result = apply_monkey_patch()

        self.assertTrue(result)

        # litellm.completion must now be a different callable than the mock
        # that was installed before the patch was applied.
        import litellm
        self.assertIsNot(litellm.completion, mock_completion)
        self.assertTrue(callable(litellm.completion))

    @patch('crewai.utilities.ollama.monkey_patch.query_ollama')
    @patch('litellm.completion')
    def test_custom_completion_non_ollama_model(self, mock_original_completion, mock_query_ollama):
        """Test that non-Ollama models are passed to the original completion function."""
        completion = self._patched_completion()

        completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello"}],
        )

        # Non-Ollama models go straight to the original litellm.completion.
        mock_original_completion.assert_called_once()
        mock_query_ollama.assert_not_called()

    @patch('crewai.utilities.ollama.monkey_patch.query_ollama')
    @patch('litellm.completion')
    def test_custom_completion_ollama_model_non_streaming(self, mock_original_completion, mock_query_ollama):
        """Test the custom completion function with an Ollama model in non-streaming mode."""
        mock_query_ollama.return_value = "This is a test response."
        completion = self._patched_completion()

        result = completion(
            model="ollama/llama3",
            messages=[{"role": "user", "content": "Hello"}],
            temperature=0.5,
        )

        # "ollama/" models are routed to query_ollama, bypassing litellm.
        mock_original_completion.assert_not_called()
        mock_query_ollama.assert_called_once()

        # The plain string from query_ollama is wrapped in a response object
        # exposing choices/model/usage.
        choice = result.choices[0]
        self.assertEqual(choice.message.content, "This is a test response.")
        self.assertEqual(choice.finish_reason, "stop")
        self.assertEqual(result.model, "ollama/llama3")
        self.assertIsNotNone(result.usage)

    @patch('crewai.utilities.ollama.monkey_patch.query_ollama')
    @patch('litellm.completion')
    def test_custom_completion_ollama_model_streaming(self, mock_original_completion, mock_query_ollama):
        """Test the custom completion function with an Ollama model in streaming mode."""
        expected_fragments = ["This", " is", " a", " test", " response."]
        mock_query_ollama.return_value = (chunk for chunk in expected_fragments)
        completion = self._patched_completion()

        result = completion(
            model="ollama/llama3",
            messages=[{"role": "user", "content": "Hello"}],
            temperature=0.5,
            stream=True,
        )

        mock_original_completion.assert_not_called()
        mock_query_ollama.assert_called_once()

        self.assertTrue(hasattr(result, '__next__'))
        chunks = list(result)

        # One chunk per fragment, plus a terminating chunk carrying finish_reason.
        self.assertEqual(len(chunks), len(expected_fragments) + 1)

        for chunk, expected_content in zip(chunks, expected_fragments):
            choice = chunk.choices[0]
            self.assertEqual(choice.delta.content, expected_content)
            self.assertEqual(choice.delta.role, "assistant")
            self.assertIsNone(choice.finish_reason)

        final = chunks[-1]
        self.assertEqual(final.choices[0].delta.content, "")
        self.assertEqual(final.choices[0].finish_reason, "stop")
        self.assertIsNotNone(final.usage)