Compare commits

...

1 Commits

Author SHA1 Message Date
Devin AI
d049f56986 Fix issue #2343: Add Ollama monkey patch for local LLM integration
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-12 09:54:21 +00:00
7 changed files with 617 additions and 1 deletions

View File

@@ -142,7 +142,9 @@ You can connect to OpenAI-compatible LLMs using either environment variables or
## Using Local Models with Ollama
For local models like those provided by Ollama:
CrewAI provides two ways to use local models with Ollama:
### Method 1: Direct Connection (Standard)
<Steps>
<Step title="Download and install Ollama">
@@ -165,6 +167,49 @@ For local models like those provided by Ollama:
</Step>
</Steps>
### Method 2: Using the Ollama Monkey Patch (Recommended)
For a more robust integration with Ollama, CrewAI provides a monkey patch that enhances compatibility and performance:
<Steps>
<Step title="Download and install Ollama">
[Click here to download and install Ollama](https://ollama.com/download)
</Step>
<Step title="Pull the desired model">
For example, run `ollama pull llama3` to download the model.
</Step>
<Step title="Apply the monkey patch">
<CodeGroup>
```python Code
from crewai import Agent, Crew, Task, LLM
from crewai import apply_monkey_patch
# Apply the monkey patch at the beginning of your script
apply_monkey_patch()
# Create an LLM instance with an Ollama model
llm = LLM(model="ollama/llama3", base_url="http://localhost:11434")
# Use the LLM instance with CrewAI
agent = Agent(
role='Local AI Expert',
goal='Process information using a local model',
backstory="An AI assistant running on local hardware.",
llm=llm
)
```
</CodeGroup>
</Step>
</Steps>
The monkey patch provides several advantages:
- Improved handling of streaming responses
- Better error handling and logging
- More accurate token counting
- Enhanced compatibility with CrewAI's features
For more details, see the [Ollama integration README](https://github.com/crewAIinc/crewAI/blob/main/src/crewai/utilities/ollama/README.md).
## Changing the Base API URL
You can change the base API URL for any LLM provider by setting the `base_url` parameter:

View File

@@ -7,6 +7,7 @@ from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM
from crewai.process import Process
from crewai.task import Task
from crewai.utilities.ollama.monkey_patch import apply_monkey_patch
warnings.filterwarnings(
"ignore",
@@ -23,4 +24,5 @@ __all__ = [
"LLM",
"Flow",
"Knowledge",
"apply_monkey_patch",
]

View File

@@ -0,0 +1,67 @@
# Ollama Integration for CrewAI
This module provides integration between CrewAI and Ollama, allowing you to use local LLMs with CrewAI without requiring an OpenAI API key.
## Overview
The integration works by applying a monkey patch to `litellm.completion`, which is used by CrewAI to communicate with LLMs. The monkey patch intercepts calls to Ollama models and redirects them to the local Ollama API instead of going through LiteLLM's normal channels.
## Usage
To use this integration, you need to:
1. Install and run Ollama locally (see [ollama.ai](https://ollama.ai))
2. Pull the desired model (e.g., `ollama pull llama3`)
3. Apply the monkey patch at the beginning of your CrewAI application:
```python
from crewai import Agent, Crew, Task
from crewai.llm import LLM
from crewai import apply_monkey_patch
# Apply the monkey patch
apply_monkey_patch()
# Create an LLM instance with an Ollama model
llm = LLM(model="ollama/llama3", base_url="http://localhost:11434")
# Use the LLM instance with CrewAI
agent = Agent(
role="Local AI Expert",
goal="Process information using a local model",
backstory="An AI assistant running on local hardware.",
llm=llm
)
# Continue with your CrewAI application...
```
## Configuration
The Ollama integration supports the following configuration options:
- `model`: The name of the Ollama model to use, prefixed with "ollama/" (e.g., "ollama/llama3")
- `base_url`: The base URL for the Ollama API (default: "http://localhost:11434")
- `temperature`: The temperature parameter for generation (default: 0.7)
- `stream`: Whether to stream the response (default: False)
## Supported Models
Any model available in your local Ollama installation can be used with this integration. Just prefix the model name with "ollama/" when creating the LLM instance.
## Limitations
- Tool calling is not fully supported with local Ollama models
- Some advanced features like response formatting may not work as expected
- Token counting is estimated rather than exact
## Troubleshooting
If you encounter issues with the Ollama integration, check the following:
1. Make sure Ollama is running locally
2. Verify that you've pulled the model you're trying to use
3. Check that the base_url is correct
4. Look for error messages in the logs
For more information, see the [CrewAI documentation](https://docs.crewai.com/how-to/LLM-Connections/).

View File

@@ -0,0 +1,10 @@
"""
Ollama integration utilities for CrewAI.
This package provides utilities for integrating CrewAI with Ollama,
a local LLM provider.
"""
from .monkey_patch import apply_monkey_patch
__all__ = ["apply_monkey_patch"]

View File

@@ -0,0 +1,236 @@
"""
Monkey patch for litellm.completion to enable local Ollama LLM usage.
This module provides a monkey patch for litellm.completion that allows CrewAI
to work with local Ollama LLM instances without requiring an OpenAI API key.
"""
import json
import logging
import requests
from types import SimpleNamespace
from typing import Dict, Any, List, Generator, Optional, Union
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def query_ollama(
prompt: str,
model: str = "llama3",
base_url: str = "http://localhost:11434",
stream: bool = False,
temperature: float = 0.7,
stop: Optional[List[str]] = None
) -> Union[str, Generator]:
"""
Query Ollama API directly
Args:
prompt: The prompt to send to Ollama
model: The model to use (default: llama3)
base_url: The base URL for Ollama API (default: http://localhost:11434)
stream: Whether to stream the response (default: False)
temperature: Temperature parameter for generation (default: 0.7)
stop: Optional list of stop sequences
Returns:
The response text from Ollama or a generator for streaming
"""
url = f"{base_url}/api/generate"
data = {
"model": model,
"prompt": prompt,
"options": {
"temperature": temperature,
"num_predict": 100,
"stream": stream
}
}
# Add stop sequences if provided
if stop and isinstance(stop, list) and len(stop) > 0:
data["options"]["stop"] = stop
try:
if stream:
# For streaming, return a generator
response = requests.post(url, json=data, stream=True)
response.raise_for_status()
def stream_generator():
for line in response.iter_lines():
if line:
chunk = json.loads(line)
if "response" in chunk:
yield chunk["response"]
if chunk.get("done", False):
break
return stream_generator()
else:
# For non-streaming, return the complete response
response = requests.post(url, json=data)
response.raise_for_status()
return response.json().get("response", "")
except Exception as e:
logger.error(f"Error querying Ollama API: {str(e)}")
return f"Error: {str(e)}"
def extract_prompt_from_messages(messages: List[Dict[str, str]]) -> str:
"""
Extract a prompt from a list of messages
Args:
messages: List of message dictionaries with 'role' and 'content' keys
Returns:
A formatted prompt string
"""
prompt = ""
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", "")
if role and content:
prompt += f"### {role.capitalize()}:\n{content}\n\n"
return prompt
def apply_monkey_patch() -> bool:
"""
Apply the monkey patch to litellm.completion
This function saves the original litellm.completion function and
replaces it with a custom implementation that handles Ollama models.
Returns:
bool: True if the patch was applied successfully, False otherwise
"""
try:
# Import litellm
import litellm
logger.info("Successfully imported litellm")
# Save the original completion function
original_completion = litellm.completion
logger.info("Saved original litellm.completion function")
# Define the monkey patch function
def custom_completion(*args, **kwargs):
"""Custom implementation of litellm.completion for Ollama"""
model = kwargs.get("model", "")
messages = kwargs.get("messages", [])
temperature = kwargs.get("temperature", 0.7)
stream = kwargs.get("stream", False)
base_url = kwargs.get("base_url", "http://localhost:11434")
stop = kwargs.get("stop", None)
logger.debug(f"Intercepted call to litellm.completion with model: {model}")
# Only intercept calls for Ollama models
if not model.startswith("ollama/"):
logger.debug("Not an Ollama model, calling original litellm.completion")
return original_completion(*args, **kwargs)
# Extract the actual model name from the 'ollama/model' format
ollama_model = model.split("/")[1]
logger.info(f"Handling Ollama model: {ollama_model}")
# Extract prompt from messages
prompt = extract_prompt_from_messages(messages)
logger.debug(f"Generated prompt: {prompt[:100]}...")
# Query Ollama
if stream:
logger.debug("Using streaming mode")
# For streaming, return a generator that yields chunks in the format expected by CrewAI
# First, get the generator from query_ollama
chunks_generator = query_ollama(
prompt,
model=ollama_model,
base_url=base_url,
stream=True,
temperature=temperature,
stop=stop
)
# Then create a wrapper generator that transforms the chunks
def stream_response():
for chunk in chunks_generator:
yield SimpleNamespace(
choices=[
SimpleNamespace(
delta=SimpleNamespace(
content=chunk,
role="assistant"
),
index=0,
finish_reason=None
)
],
usage=None,
model=model
)
# Final chunk with finish_reason and usage
yield SimpleNamespace(
choices=[
SimpleNamespace(
delta=SimpleNamespace(
content="",
role="assistant"
),
index=0,
finish_reason="stop"
)
],
usage=SimpleNamespace(
prompt_tokens=len(prompt.split()),
completion_tokens=len(prompt.split()) * 2, # Estimate
total_tokens=len(prompt.split()) * 3 # Estimate
),
model=model
)
return stream_response()
else:
logger.debug("Using non-streaming mode")
# For non-streaming, return a complete response object
response_text = query_ollama(
prompt,
model=ollama_model,
base_url=base_url,
temperature=temperature,
stop=stop
)
logger.debug(f"Received response: {response_text[:100]}...")
return SimpleNamespace(
choices=[
SimpleNamespace(
message=SimpleNamespace(
content=response_text,
tool_calls=None,
role="assistant"
),
finish_reason="stop",
index=0
)
],
usage=SimpleNamespace(
prompt_tokens=len(prompt.split()),
completion_tokens=len(response_text.split()),
total_tokens=len(prompt.split()) + len(response_text.split())
),
id="ollama-response",
model=model,
created=123456789
)
# Apply the monkey patch
litellm.completion = custom_completion
logger.info("Applied monkey patch to litellm.completion")
return True
except ImportError as e:
logger.error(f"Error importing litellm: {e}")
return False
except Exception as e:
logger.error(f"Unexpected error: {e}")
return False

View File

@@ -0,0 +1,3 @@
"""
Tests for Ollama integration utilities.
"""

View File

@@ -0,0 +1,253 @@
"""
Tests for the Ollama monkey patch utility.
"""
import unittest
from unittest.mock import patch, MagicMock, call
import json
from types import SimpleNamespace
import pytest
from crewai.utilities.ollama.monkey_patch import (
apply_monkey_patch,
query_ollama,
extract_prompt_from_messages
)
class TestOllamaMonkeyPatch(unittest.TestCase):
"""Test cases for the Ollama monkey patch utility."""
def test_extract_prompt_from_messages(self):
"""Test extracting a prompt from a list of messages."""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you!"},
{"role": "user", "content": "Tell me about CrewAI."}
]
prompt = extract_prompt_from_messages(messages)
self.assertIn("System:", prompt)
self.assertIn("You are a helpful assistant.", prompt)
self.assertIn("User:", prompt)
self.assertIn("Hello, how are you?", prompt)
self.assertIn("Assistant:", prompt)
self.assertIn("I'm doing well, thank you!", prompt)
self.assertIn("Tell me about CrewAI.", prompt)
@patch('requests.post')
def test_query_ollama_non_streaming(self, mock_post):
"""Test querying Ollama API in non-streaming mode."""
# Mock the response
mock_response = MagicMock()
mock_response.json.return_value = {"response": "This is a test response."}
mock_post.return_value = mock_response
# Call the function
result = query_ollama(
prompt="Test prompt",
model="llama3",
base_url="http://localhost:11434",
stream=False,
temperature=0.5
)
# Verify the result
self.assertEqual(result, "This is a test response.")
# Verify the API call
mock_post.assert_called_once()
args, kwargs = mock_post.call_args
self.assertEqual(args[0], "http://localhost:11434/api/generate")
self.assertEqual(kwargs["json"]["model"], "llama3")
self.assertEqual(kwargs["json"]["prompt"], "Test prompt")
self.assertEqual(kwargs["json"]["options"]["temperature"], 0.5)
self.assertEqual(kwargs["json"]["options"]["stream"], False)
@patch('requests.post')
def test_query_ollama_streaming(self, mock_post):
"""Test querying Ollama API in streaming mode."""
# Mock the response for streaming
mock_response = MagicMock()
mock_response.iter_lines.return_value = [
json.dumps({"response": "This"}).encode(),
json.dumps({"response": " is"}).encode(),
json.dumps({"response": " a"}).encode(),
json.dumps({"response": " test"}).encode(),
json.dumps({"response": " response.", "done": True}).encode()
]
mock_post.return_value = mock_response
# Call the function
result = query_ollama(
prompt="Test prompt",
model="llama3",
base_url="http://localhost:11434",
stream=True,
temperature=0.5
)
# Verify the result is a generator
self.assertTrue(hasattr(result, '__next__'))
# Consume the generator and verify the results
chunks = list(result)
self.assertEqual(chunks, ["This", " is", " a", " test", " response."])
# Verify the API call
mock_post.assert_called_once()
args, kwargs = mock_post.call_args
self.assertEqual(args[0], "http://localhost:11434/api/generate")
self.assertEqual(kwargs["json"]["model"], "llama3")
self.assertEqual(kwargs["json"]["prompt"], "Test prompt")
self.assertEqual(kwargs["json"]["options"]["temperature"], 0.5)
self.assertEqual(kwargs["json"]["options"]["stream"], True)
@patch('requests.post')
def test_query_ollama_with_stop_sequences(self, mock_post):
"""Test querying Ollama API with stop sequences."""
# Mock the response
mock_response = MagicMock()
mock_response.json.return_value = {"response": "This is a test response."}
mock_post.return_value = mock_response
# Call the function with stop sequences
result = query_ollama(
prompt="Test prompt",
model="llama3",
stop=["END", "STOP"]
)
# Verify the API call includes stop sequences
mock_post.assert_called_once()
args, kwargs = mock_post.call_args
self.assertEqual(kwargs["json"]["options"]["stop"], ["END", "STOP"])
@patch('requests.post')
def test_query_ollama_error_handling(self, mock_post):
"""Test error handling in query_ollama."""
# Mock the response to raise an exception
mock_post.side_effect = Exception("Test error")
# Call the function
result = query_ollama(prompt="Test prompt")
# Verify the result contains the error message
self.assertIn("Error:", result)
self.assertIn("Test error", result)
@patch('litellm.completion')
def test_apply_monkey_patch(self, mock_completion):
"""Test applying the monkey patch."""
# Apply the monkey patch
result = apply_monkey_patch()
# Verify the result
self.assertTrue(result)
# Verify that litellm.completion has been replaced
import litellm
self.assertNotEqual(litellm.completion, mock_completion)
@patch('crewai.utilities.ollama.monkey_patch.query_ollama')
@patch('litellm.completion')
def test_custom_completion_non_ollama_model(self, mock_original_completion, mock_query_ollama):
"""Test that non-Ollama models are passed to the original completion function."""
# Apply the monkey patch
apply_monkey_patch()
# Import litellm to get the patched completion function
import litellm
# Call the patched completion function with a non-Ollama model
litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello"}]
)
# Verify that the original completion function was called
mock_original_completion.assert_called_once()
# Verify that query_ollama was not called
mock_query_ollama.assert_not_called()
@patch('crewai.utilities.ollama.monkey_patch.query_ollama')
@patch('litellm.completion')
def test_custom_completion_ollama_model_non_streaming(self, mock_original_completion, mock_query_ollama):
"""Test the custom completion function with an Ollama model in non-streaming mode."""
# Set up the mock
mock_query_ollama.return_value = "This is a test response."
# Apply the monkey patch
apply_monkey_patch()
# Import litellm to get the patched completion function
import litellm
# Call the patched completion function with an Ollama model
result = litellm.completion(
model="ollama/llama3",
messages=[{"role": "user", "content": "Hello"}],
temperature=0.5
)
# Verify that the original completion function was not called
mock_original_completion.assert_not_called()
# Verify that query_ollama was called
mock_query_ollama.assert_called_once()
# Verify the result structure
self.assertEqual(result.choices[0].message.content, "This is a test response.")
self.assertEqual(result.choices[0].finish_reason, "stop")
self.assertEqual(result.model, "ollama/llama3")
self.assertIsNotNone(result.usage)
@patch('crewai.utilities.ollama.monkey_patch.query_ollama')
@patch('litellm.completion')
def test_custom_completion_ollama_model_streaming(self, mock_original_completion, mock_query_ollama):
"""Test the custom completion function with an Ollama model in streaming mode."""
# Set up the mock to return a generator
mock_query_ollama.return_value = (chunk for chunk in ["This", " is", " a", " test", " response."])
# Apply the monkey patch
apply_monkey_patch()
# Import litellm to get the patched completion function
import litellm
# Call the patched completion function with an Ollama model in streaming mode
result = litellm.completion(
model="ollama/llama3",
messages=[{"role": "user", "content": "Hello"}],
temperature=0.5,
stream=True
)
# Verify that the original completion function was not called
mock_original_completion.assert_not_called()
# Verify that query_ollama was called
mock_query_ollama.assert_called_once()
# Verify the result is a generator
self.assertTrue(hasattr(result, '__next__'))
# Consume the generator and verify the structure of each chunk
chunks = list(result)
# Verify we have the expected number of chunks (5 content chunks + 1 final chunk)
self.assertEqual(len(chunks), 6)
# Check the content of the first 5 chunks
for i, expected_content in enumerate(["This", " is", " a", " test", " response."]):
self.assertEqual(chunks[i].choices[0].delta.content, expected_content)
self.assertEqual(chunks[i].choices[0].delta.role, "assistant")
self.assertIsNone(chunks[i].choices[0].finish_reason)
# Check the final chunk
self.assertEqual(chunks[5].choices[0].delta.content, "")
self.assertEqual(chunks[5].choices[0].finish_reason, "stop")
self.assertIsNotNone(chunks[5].usage)