Fix issue #2343: Add Ollama monkey patch for local LLM integration

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-03-12 09:54:21 +00:00
parent 41a670166a
commit d049f56986
7 changed files with 617 additions and 1 deletions

View File

@@ -0,0 +1,3 @@
"""
Tests for Ollama integration utilities.
"""

View File

@@ -0,0 +1,253 @@
"""
Tests for the Ollama monkey patch utility.
"""
import unittest
from unittest.mock import patch, MagicMock, call
import json
from types import SimpleNamespace
import pytest
from crewai.utilities.ollama.monkey_patch import (
apply_monkey_patch,
query_ollama,
extract_prompt_from_messages
)
class TestOllamaMonkeyPatch(unittest.TestCase):
"""Test cases for the Ollama monkey patch utility."""
def test_extract_prompt_from_messages(self):
"""Test extracting a prompt from a list of messages."""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you!"},
{"role": "user", "content": "Tell me about CrewAI."}
]
prompt = extract_prompt_from_messages(messages)
self.assertIn("System:", prompt)
self.assertIn("You are a helpful assistant.", prompt)
self.assertIn("User:", prompt)
self.assertIn("Hello, how are you?", prompt)
self.assertIn("Assistant:", prompt)
self.assertIn("I'm doing well, thank you!", prompt)
self.assertIn("Tell me about CrewAI.", prompt)
@patch('requests.post')
def test_query_ollama_non_streaming(self, mock_post):
"""Test querying Ollama API in non-streaming mode."""
# Mock the response
mock_response = MagicMock()
mock_response.json.return_value = {"response": "This is a test response."}
mock_post.return_value = mock_response
# Call the function
result = query_ollama(
prompt="Test prompt",
model="llama3",
base_url="http://localhost:11434",
stream=False,
temperature=0.5
)
# Verify the result
self.assertEqual(result, "This is a test response.")
# Verify the API call
mock_post.assert_called_once()
args, kwargs = mock_post.call_args
self.assertEqual(args[0], "http://localhost:11434/api/generate")
self.assertEqual(kwargs["json"]["model"], "llama3")
self.assertEqual(kwargs["json"]["prompt"], "Test prompt")
self.assertEqual(kwargs["json"]["options"]["temperature"], 0.5)
self.assertEqual(kwargs["json"]["options"]["stream"], False)
@patch('requests.post')
def test_query_ollama_streaming(self, mock_post):
"""Test querying Ollama API in streaming mode."""
# Mock the response for streaming
mock_response = MagicMock()
mock_response.iter_lines.return_value = [
json.dumps({"response": "This"}).encode(),
json.dumps({"response": " is"}).encode(),
json.dumps({"response": " a"}).encode(),
json.dumps({"response": " test"}).encode(),
json.dumps({"response": " response.", "done": True}).encode()
]
mock_post.return_value = mock_response
# Call the function
result = query_ollama(
prompt="Test prompt",
model="llama3",
base_url="http://localhost:11434",
stream=True,
temperature=0.5
)
# Verify the result is a generator
self.assertTrue(hasattr(result, '__next__'))
# Consume the generator and verify the results
chunks = list(result)
self.assertEqual(chunks, ["This", " is", " a", " test", " response."])
# Verify the API call
mock_post.assert_called_once()
args, kwargs = mock_post.call_args
self.assertEqual(args[0], "http://localhost:11434/api/generate")
self.assertEqual(kwargs["json"]["model"], "llama3")
self.assertEqual(kwargs["json"]["prompt"], "Test prompt")
self.assertEqual(kwargs["json"]["options"]["temperature"], 0.5)
self.assertEqual(kwargs["json"]["options"]["stream"], True)
@patch('requests.post')
def test_query_ollama_with_stop_sequences(self, mock_post):
"""Test querying Ollama API with stop sequences."""
# Mock the response
mock_response = MagicMock()
mock_response.json.return_value = {"response": "This is a test response."}
mock_post.return_value = mock_response
# Call the function with stop sequences
result = query_ollama(
prompt="Test prompt",
model="llama3",
stop=["END", "STOP"]
)
# Verify the API call includes stop sequences
mock_post.assert_called_once()
args, kwargs = mock_post.call_args
self.assertEqual(kwargs["json"]["options"]["stop"], ["END", "STOP"])
@patch('requests.post')
def test_query_ollama_error_handling(self, mock_post):
"""Test error handling in query_ollama."""
# Mock the response to raise an exception
mock_post.side_effect = Exception("Test error")
# Call the function
result = query_ollama(prompt="Test prompt")
# Verify the result contains the error message
self.assertIn("Error:", result)
self.assertIn("Test error", result)
@patch('litellm.completion')
def test_apply_monkey_patch(self, mock_completion):
"""Test applying the monkey patch."""
# Apply the monkey patch
result = apply_monkey_patch()
# Verify the result
self.assertTrue(result)
# Verify that litellm.completion has been replaced
import litellm
self.assertNotEqual(litellm.completion, mock_completion)
@patch('crewai.utilities.ollama.monkey_patch.query_ollama')
@patch('litellm.completion')
def test_custom_completion_non_ollama_model(self, mock_original_completion, mock_query_ollama):
"""Test that non-Ollama models are passed to the original completion function."""
# Apply the monkey patch
apply_monkey_patch()
# Import litellm to get the patched completion function
import litellm
# Call the patched completion function with a non-Ollama model
litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello"}]
)
# Verify that the original completion function was called
mock_original_completion.assert_called_once()
# Verify that query_ollama was not called
mock_query_ollama.assert_not_called()
@patch('crewai.utilities.ollama.monkey_patch.query_ollama')
@patch('litellm.completion')
def test_custom_completion_ollama_model_non_streaming(self, mock_original_completion, mock_query_ollama):
"""Test the custom completion function with an Ollama model in non-streaming mode."""
# Set up the mock
mock_query_ollama.return_value = "This is a test response."
# Apply the monkey patch
apply_monkey_patch()
# Import litellm to get the patched completion function
import litellm
# Call the patched completion function with an Ollama model
result = litellm.completion(
model="ollama/llama3",
messages=[{"role": "user", "content": "Hello"}],
temperature=0.5
)
# Verify that the original completion function was not called
mock_original_completion.assert_not_called()
# Verify that query_ollama was called
mock_query_ollama.assert_called_once()
# Verify the result structure
self.assertEqual(result.choices[0].message.content, "This is a test response.")
self.assertEqual(result.choices[0].finish_reason, "stop")
self.assertEqual(result.model, "ollama/llama3")
self.assertIsNotNone(result.usage)
@patch('crewai.utilities.ollama.monkey_patch.query_ollama')
@patch('litellm.completion')
def test_custom_completion_ollama_model_streaming(self, mock_original_completion, mock_query_ollama):
"""Test the custom completion function with an Ollama model in streaming mode."""
# Set up the mock to return a generator
mock_query_ollama.return_value = (chunk for chunk in ["This", " is", " a", " test", " response."])
# Apply the monkey patch
apply_monkey_patch()
# Import litellm to get the patched completion function
import litellm
# Call the patched completion function with an Ollama model in streaming mode
result = litellm.completion(
model="ollama/llama3",
messages=[{"role": "user", "content": "Hello"}],
temperature=0.5,
stream=True
)
# Verify that the original completion function was not called
mock_original_completion.assert_not_called()
# Verify that query_ollama was called
mock_query_ollama.assert_called_once()
# Verify the result is a generator
self.assertTrue(hasattr(result, '__next__'))
# Consume the generator and verify the structure of each chunk
chunks = list(result)
# Verify we have the expected number of chunks (5 content chunks + 1 final chunk)
self.assertEqual(len(chunks), 6)
# Check the content of the first 5 chunks
for i, expected_content in enumerate(["This", " is", " a", " test", " response."]):
self.assertEqual(chunks[i].choices[0].delta.content, expected_content)
self.assertEqual(chunks[i].choices[0].delta.role, "assistant")
self.assertIsNone(chunks[i].choices[0].finish_reason)
# Check the final chunk
self.assertEqual(chunks[5].choices[0].delta.content, "")
self.assertEqual(chunks[5].choices[0].finish_reason, "stop")
self.assertIsNotNone(chunks[5].usage)