mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-06 14:48:29 +00:00

Compare commits: 1.1.0...devin/1741, 1 commit (d049f56986)
@@ -142,7 +142,9 @@ You can connect to OpenAI-compatible LLMs using either environment variables or

## Using Local Models with Ollama

For local models like those provided by Ollama:
CrewAI provides two ways to use local models with Ollama:

### Method 1: Direct Connection (Standard)

<Steps>
<Step title="Download and install Ollama">
@@ -165,6 +167,49 @@ For local models like those provided by Ollama:
</Step>
</Steps>

### Method 2: Using the Ollama Monkey Patch (Recommended)

For a more robust integration with Ollama, CrewAI provides a monkey patch that enhances compatibility and performance:

<Steps>
<Step title="Download and install Ollama">
[Click here to download and install Ollama](https://ollama.com/download)
</Step>
<Step title="Pull the desired model">
For example, run `ollama pull llama3` to download the model.
</Step>
<Step title="Apply the monkey patch">
<CodeGroup>
```python Code
from crewai import Agent, Crew, Task, LLM
from crewai import apply_monkey_patch

# Apply the monkey patch at the beginning of your script
apply_monkey_patch()

# Create an LLM instance with an Ollama model
llm = LLM(model="ollama/llama3", base_url="http://localhost:11434")

# Use the LLM instance with CrewAI
agent = Agent(
    role='Local AI Expert',
    goal='Process information using a local model',
    backstory="An AI assistant running on local hardware.",
    llm=llm
)
```
</CodeGroup>
</Step>
</Steps>

The monkey patch provides several advantages:
- Improved handling of streaming responses
- Better error handling and logging
- More accurate token counting
- Enhanced compatibility with CrewAI's features

For more details, see the [Ollama integration README](https://github.com/crewAIinc/crewAI/blob/main/src/crewai/utilities/ollama/README.md).
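
To see the patched LLM drive a full run, the agent shown above can be wired into a task and crew. The snippet below is a minimal sketch, assuming the usual `Task`/`Crew` constructor arguments (`description`, `expected_output`, `agent`, `agents`, `tasks`) and `crew.kickoff()` as the entry point:

```python Code
from crewai import Agent, Crew, Task, LLM
from crewai import apply_monkey_patch

# Patch litellm before any LLM calls are made
apply_monkey_patch()

# Local Ollama-backed LLM (assumes `ollama pull llama3` has been run)
llm = LLM(model="ollama/llama3", base_url="http://localhost:11434")

agent = Agent(
    role='Local AI Expert',
    goal='Process information using a local model',
    backstory="An AI assistant running on local hardware.",
    llm=llm
)

task = Task(
    description="Summarize why local LLMs are useful for prototyping.",
    expected_output="A short paragraph with two or three reasons.",
    agent=agent
)

crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
print(result)
```
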
## Changing the Base API URL

You can change the base API URL for any LLM provider by setting the `base_url` parameter:
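
As a minimal sketch (the model name and endpoint URL below are placeholders, not recommendations), the parameter is passed directly to the `LLM` constructor:

```python Code
from crewai import LLM

# Point an OpenAI-compatible model at a custom endpoint;
# substitute your own server URL and model name.
llm = LLM(
    model="openai/gpt-4o-mini",
    base_url="https://my-gateway.example.com/v1"
)
```
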
@@ -7,6 +7,7 @@ from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM
from crewai.process import Process
from crewai.task import Task
from crewai.utilities.ollama.monkey_patch import apply_monkey_patch

warnings.filterwarnings(
    "ignore",
@@ -23,4 +24,5 @@ __all__ = [
    "LLM",
    "Flow",
    "Knowledge",
    "apply_monkey_patch",
]
src/crewai/utilities/ollama/README.md (new file, 67 lines)
@@ -0,0 +1,67 @@
# Ollama Integration for CrewAI

This module provides integration between CrewAI and Ollama, allowing you to use local LLMs with CrewAI without requiring an OpenAI API key.

## Overview

The integration works by applying a monkey patch to `litellm.completion`, which is used by CrewAI to communicate with LLMs. The monkey patch intercepts calls to Ollama models and redirects them to the local Ollama API instead of going through LiteLLM's normal channels.
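
Conceptually, the patch keeps a reference to the original `litellm.completion` and only takes over when the model name carries the `ollama/` prefix. The sketch below is a simplified illustration of that flow; unlike the real implementation in `monkey_patch.py`, it inlines the request logic and returns plain text instead of a LiteLLM-style response object:

```python
import litellm
import requests

# Keep a handle on the original function so non-Ollama traffic is untouched.
_original_completion = litellm.completion

def _patched_completion(*args, **kwargs):
    model = kwargs.get("model", "")
    if not model.startswith("ollama/"):
        # Anything that is not an Ollama model goes through LiteLLM as usual.
        return _original_completion(*args, **kwargs)
    # For "ollama/<name>" models, talk to the local Ollama API directly.
    prompt = "\n".join(m.get("content", "") for m in kwargs.get("messages", []))
    resp = requests.post(
        "http://localhost:11434/api/generate",
        json={"model": model.split("/", 1)[1], "prompt": prompt, "stream": False},
    )
    return resp.json().get("response", "")

litellm.completion = _patched_completion
```
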
## Usage

To use this integration, you need to:

1. Install and run Ollama locally (see [ollama.ai](https://ollama.ai))
2. Pull the desired model (e.g., `ollama pull llama3`)
3. Apply the monkey patch at the beginning of your CrewAI application:

```python
from crewai import Agent, Crew, Task
from crewai.llm import LLM
from crewai import apply_monkey_patch

# Apply the monkey patch
apply_monkey_patch()

# Create an LLM instance with an Ollama model
llm = LLM(model="ollama/llama3", base_url="http://localhost:11434")

# Use the LLM instance with CrewAI
agent = Agent(
    role="Local AI Expert",
    goal="Process information using a local model",
    backstory="An AI assistant running on local hardware.",
    llm=llm
)

# Continue with your CrewAI application...
```

## Configuration

The Ollama integration supports the following configuration options:

- `model`: The name of the Ollama model to use, prefixed with "ollama/" (e.g., "ollama/llama3")
- `base_url`: The base URL for the Ollama API (default: "http://localhost:11434")
- `temperature`: The temperature parameter for generation (default: 0.7)
- `stream`: Whether to stream the response (default: False)
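
For illustration, here is how those options fit together, assuming the `LLM` constructor forwards `temperature` and `stream` to the underlying completion call as it does for other providers:

```python
from crewai import LLM

# All values below are the documented defaults except temperature,
# which is lowered for more deterministic output.
llm = LLM(
    model="ollama/llama3",
    base_url="http://localhost:11434",
    temperature=0.2,
    stream=False
)
```
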
## Supported Models

Any model available in your local Ollama installation can be used with this integration. Just prefix the model name with "ollama/" when creating the LLM instance.

## Limitations

- Tool calling is not fully supported with local Ollama models
- Some advanced features like response formatting may not work as expected
- Token counting is estimated rather than exact

## Troubleshooting

If you encounter issues with the Ollama integration, check the following:

1. Make sure Ollama is running locally
2. Verify that you've pulled the model you're trying to use
3. Check that the `base_url` is correct
4. Look for error messages in the logs
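
The first three checks can be scripted. Below is a minimal sketch that assumes Ollama's standard `/api/tags` endpoint (which lists locally pulled models) is reachable at the configured base URL:

```python
import requests

BASE_URL = "http://localhost:11434"
MODEL = "llama3"

try:
    # /api/tags lists the models that have been pulled locally
    resp = requests.get(f"{BASE_URL}/api/tags", timeout=5)
    resp.raise_for_status()
    names = [m.get("name", "") for m in resp.json().get("models", [])]
    if any(name.startswith(MODEL) for name in names):
        print(f"Ollama is up and '{MODEL}' is available.")
    else:
        print(f"Ollama is up, but '{MODEL}' has not been pulled yet: {names}")
except requests.RequestException as exc:
    print(f"Could not reach Ollama at {BASE_URL}: {exc}")
```
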
For more information, see the [CrewAI documentation](https://docs.crewai.com/how-to/LLM-Connections/).
src/crewai/utilities/ollama/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
"""
Ollama integration utilities for CrewAI.

This package provides utilities for integrating CrewAI with Ollama,
a local LLM provider.
"""

from .monkey_patch import apply_monkey_patch

__all__ = ["apply_monkey_patch"]
src/crewai/utilities/ollama/monkey_patch.py (new file, 236 lines)
@@ -0,0 +1,236 @@
"""
Monkey patch for litellm.completion to enable local Ollama LLM usage.

This module provides a monkey patch for litellm.completion that allows CrewAI
to work with local Ollama LLM instances without requiring an OpenAI API key.
"""

import json
import logging
import requests
from types import SimpleNamespace
from typing import Dict, Any, List, Generator, Optional, Union

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def query_ollama(
    prompt: str,
    model: str = "llama3",
    base_url: str = "http://localhost:11434",
    stream: bool = False,
    temperature: float = 0.7,
    stop: Optional[List[str]] = None
) -> Union[str, Generator]:
    """
    Query the Ollama API directly.

    Args:
        prompt: The prompt to send to Ollama
        model: The model to use (default: llama3)
        base_url: The base URL for the Ollama API (default: http://localhost:11434)
        stream: Whether to stream the response (default: False)
        temperature: Temperature parameter for generation (default: 0.7)
        stop: Optional list of stop sequences

    Returns:
        The response text from Ollama, or a generator of text chunks when streaming
    """
    url = f"{base_url}/api/generate"
    data = {
        "model": model,
        "prompt": prompt,
        # `stream` is a top-level field of the Ollama generate request, not a model option
        "stream": stream,
        "options": {
            "temperature": temperature,
            # Caps each generation at 100 tokens
            "num_predict": 100,
        }
    }

    # Add stop sequences if provided
    if stop and isinstance(stop, list) and len(stop) > 0:
        data["options"]["stop"] = stop

    try:
        if stream:
            # For streaming, return a generator
            response = requests.post(url, json=data, stream=True)
            response.raise_for_status()

            def stream_generator():
                # Ollama streams newline-delimited JSON objects
                for line in response.iter_lines():
                    if line:
                        chunk = json.loads(line)
                        if "response" in chunk:
                            yield chunk["response"]
                        if chunk.get("done", False):
                            break

            return stream_generator()
        else:
            # For non-streaming, return the complete response
            response = requests.post(url, json=data)
            response.raise_for_status()
            return response.json().get("response", "")
    except Exception as e:
        logger.error(f"Error querying Ollama API: {str(e)}")
        return f"Error: {str(e)}"
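
# Illustrative usage of query_ollama (shown for reference, not executed here):
#
#     text = query_ollama("Say hello in one sentence.", model="llama3")
#     for chunk in query_ollama("Tell a very short story.", model="llama3", stream=True):
#         print(chunk, end="")
#
# The streaming form yields plain text chunks; failed requests return an
# "Error: ..." string rather than raising.
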
def extract_prompt_from_messages(messages: List[Dict[str, str]]) -> str:
    """
    Extract a prompt from a list of messages.

    Args:
        messages: List of message dictionaries with 'role' and 'content' keys

    Returns:
        A formatted prompt string
    """
    prompt = ""
    for msg in messages:
        role = msg.get("role", "")
        content = msg.get("content", "")
        if role and content:
            prompt += f"### {role.capitalize()}:\n{content}\n\n"
    return prompt
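
# For reference, extract_prompt_from_messages flattens a chat history such as
#     [{"role": "system", "content": "Be brief."},
#      {"role": "user", "content": "Hi"}]
# into a single prompt string of the form:
#     ### System:
#     Be brief.
#
#     ### User:
#     Hi
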
def apply_monkey_patch() -> bool:
    """
    Apply the monkey patch to litellm.completion.

    This function saves the original litellm.completion function and
    replaces it with a custom implementation that handles Ollama models.

    Returns:
        bool: True if the patch was applied successfully, False otherwise
    """
    try:
        # Import litellm
        import litellm
        logger.info("Successfully imported litellm")

        # Save the original completion function
        original_completion = litellm.completion
        logger.info("Saved original litellm.completion function")

        # Define the monkey patch function
        def custom_completion(*args, **kwargs):
            """Custom implementation of litellm.completion for Ollama"""
            model = kwargs.get("model", "")
            messages = kwargs.get("messages", [])
            temperature = kwargs.get("temperature", 0.7)
            stream = kwargs.get("stream", False)
            base_url = kwargs.get("base_url", "http://localhost:11434")
            stop = kwargs.get("stop", None)

            logger.debug(f"Intercepted call to litellm.completion with model: {model}")

            # Only intercept calls for Ollama models
            if not model.startswith("ollama/"):
                logger.debug("Not an Ollama model, calling original litellm.completion")
                return original_completion(*args, **kwargs)

            # Extract the actual model name from the 'ollama/model' format
            ollama_model = model.split("/")[1]
            logger.info(f"Handling Ollama model: {ollama_model}")

            # Extract prompt from messages
            prompt = extract_prompt_from_messages(messages)
            logger.debug(f"Generated prompt: {prompt[:100]}...")

            # Query Ollama
            if stream:
                logger.debug("Using streaming mode")
                # For streaming, return a generator that yields chunks in the format
                # expected by CrewAI. First, get the generator from query_ollama.
                chunks_generator = query_ollama(
                    prompt,
                    model=ollama_model,
                    base_url=base_url,
                    stream=True,
                    temperature=temperature,
                    stop=stop
                )

                # Then create a wrapper generator that transforms the chunks
                def stream_response():
                    for chunk in chunks_generator:
                        yield SimpleNamespace(
                            choices=[
                                SimpleNamespace(
                                    delta=SimpleNamespace(
                                        content=chunk,
                                        role="assistant"
                                    ),
                                    index=0,
                                    finish_reason=None
                                )
                            ],
                            usage=None,
                            model=model
                        )
                    # Final chunk with finish_reason and usage
                    yield SimpleNamespace(
                        choices=[
                            SimpleNamespace(
                                delta=SimpleNamespace(
                                    content="",
                                    role="assistant"
                                ),
                                index=0,
                                finish_reason="stop"
                            )
                        ],
                        usage=SimpleNamespace(
                            prompt_tokens=len(prompt.split()),
                            completion_tokens=len(prompt.split()) * 2,  # Estimate
                            total_tokens=len(prompt.split()) * 3  # Estimate
                        ),
                        model=model
                    )

                return stream_response()
            else:
                logger.debug("Using non-streaming mode")
                # For non-streaming, return a complete response object
                response_text = query_ollama(
                    prompt,
                    model=ollama_model,
                    base_url=base_url,
                    temperature=temperature,
                    stop=stop
                )
                logger.debug(f"Received response: {response_text[:100]}...")
                return SimpleNamespace(
                    choices=[
                        SimpleNamespace(
                            message=SimpleNamespace(
                                content=response_text,
                                tool_calls=None,
                                role="assistant"
                            ),
                            finish_reason="stop",
                            index=0
                        )
                    ],
                    usage=SimpleNamespace(
                        prompt_tokens=len(prompt.split()),
                        completion_tokens=len(response_text.split()),
                        total_tokens=len(prompt.split()) + len(response_text.split())
                    ),
                    # Placeholder response id and creation timestamp
                    id="ollama-response",
                    model=model,
                    created=123456789
                )

        # Apply the monkey patch
        litellm.completion = custom_completion
        logger.info("Applied monkey patch to litellm.completion")

        return True
    except ImportError as e:
        logger.error(f"Error importing litellm: {e}")
        return False
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        return False
tests/utilities/ollama/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
"""
Tests for Ollama integration utilities.
"""
tests/utilities/ollama/test_monkey_patch.py (new file, 253 lines)
@@ -0,0 +1,253 @@
"""
Tests for the Ollama monkey patch utility.
"""

import unittest
from unittest.mock import patch, MagicMock, call
import json
from types import SimpleNamespace
import pytest

from crewai.utilities.ollama.monkey_patch import (
    apply_monkey_patch,
    query_ollama,
    extract_prompt_from_messages
)


class TestOllamaMonkeyPatch(unittest.TestCase):
    """Test cases for the Ollama monkey patch utility."""

    def test_extract_prompt_from_messages(self):
        """Test extracting a prompt from a list of messages."""
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"},
            {"role": "assistant", "content": "I'm doing well, thank you!"},
            {"role": "user", "content": "Tell me about CrewAI."}
        ]

        prompt = extract_prompt_from_messages(messages)

        self.assertIn("System:", prompt)
        self.assertIn("You are a helpful assistant.", prompt)
        self.assertIn("User:", prompt)
        self.assertIn("Hello, how are you?", prompt)
        self.assertIn("Assistant:", prompt)
        self.assertIn("I'm doing well, thank you!", prompt)
        self.assertIn("Tell me about CrewAI.", prompt)

    @patch('requests.post')
    def test_query_ollama_non_streaming(self, mock_post):
        """Test querying Ollama API in non-streaming mode."""
        # Mock the response
        mock_response = MagicMock()
        mock_response.json.return_value = {"response": "This is a test response."}
        mock_post.return_value = mock_response

        # Call the function
        result = query_ollama(
            prompt="Test prompt",
            model="llama3",
            base_url="http://localhost:11434",
            stream=False,
            temperature=0.5
        )

        # Verify the result
        self.assertEqual(result, "This is a test response.")

        # Verify the API call
        mock_post.assert_called_once()
        args, kwargs = mock_post.call_args
        self.assertEqual(args[0], "http://localhost:11434/api/generate")
        self.assertEqual(kwargs["json"]["model"], "llama3")
        self.assertEqual(kwargs["json"]["prompt"], "Test prompt")
        self.assertEqual(kwargs["json"]["options"]["temperature"], 0.5)
        # `stream` is sent as a top-level field of the request payload
        self.assertEqual(kwargs["json"]["stream"], False)

    @patch('requests.post')
    def test_query_ollama_streaming(self, mock_post):
        """Test querying Ollama API in streaming mode."""
        # Mock the response for streaming
        mock_response = MagicMock()
        mock_response.iter_lines.return_value = [
            json.dumps({"response": "This"}).encode(),
            json.dumps({"response": " is"}).encode(),
            json.dumps({"response": " a"}).encode(),
            json.dumps({"response": " test"}).encode(),
            json.dumps({"response": " response.", "done": True}).encode()
        ]
        mock_post.return_value = mock_response

        # Call the function
        result = query_ollama(
            prompt="Test prompt",
            model="llama3",
            base_url="http://localhost:11434",
            stream=True,
            temperature=0.5
        )

        # Verify the result is a generator
        self.assertTrue(hasattr(result, '__next__'))

        # Consume the generator and verify the results
        chunks = list(result)
        self.assertEqual(chunks, ["This", " is", " a", " test", " response."])

        # Verify the API call
        mock_post.assert_called_once()
        args, kwargs = mock_post.call_args
        self.assertEqual(args[0], "http://localhost:11434/api/generate")
        self.assertEqual(kwargs["json"]["model"], "llama3")
        self.assertEqual(kwargs["json"]["prompt"], "Test prompt")
        self.assertEqual(kwargs["json"]["options"]["temperature"], 0.5)
        # `stream` is sent as a top-level field of the request payload
        self.assertEqual(kwargs["json"]["stream"], True)

    @patch('requests.post')
    def test_query_ollama_with_stop_sequences(self, mock_post):
        """Test querying Ollama API with stop sequences."""
        # Mock the response
        mock_response = MagicMock()
        mock_response.json.return_value = {"response": "This is a test response."}
        mock_post.return_value = mock_response

        # Call the function with stop sequences
        result = query_ollama(
            prompt="Test prompt",
            model="llama3",
            stop=["END", "STOP"]
        )

        # Verify the API call includes stop sequences
        mock_post.assert_called_once()
        args, kwargs = mock_post.call_args
        self.assertEqual(kwargs["json"]["options"]["stop"], ["END", "STOP"])

    @patch('requests.post')
    def test_query_ollama_error_handling(self, mock_post):
        """Test error handling in query_ollama."""
        # Mock the response to raise an exception
        mock_post.side_effect = Exception("Test error")

        # Call the function
        result = query_ollama(prompt="Test prompt")

        # Verify the result contains the error message
        self.assertIn("Error:", result)
        self.assertIn("Test error", result)

    @patch('litellm.completion')
    def test_apply_monkey_patch(self, mock_completion):
        """Test applying the monkey patch."""
        # Apply the monkey patch
        result = apply_monkey_patch()

        # Verify the result
        self.assertTrue(result)

        # Verify that litellm.completion has been replaced
        import litellm
        self.assertNotEqual(litellm.completion, mock_completion)

    @patch('crewai.utilities.ollama.monkey_patch.query_ollama')
    @patch('litellm.completion')
    def test_custom_completion_non_ollama_model(self, mock_original_completion, mock_query_ollama):
        """Test that non-Ollama models are passed to the original completion function."""
        # Apply the monkey patch
        apply_monkey_patch()

        # Import litellm to get the patched completion function
        import litellm

        # Call the patched completion function with a non-Ollama model
        litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello"}]
        )

        # Verify that the original completion function was called
        mock_original_completion.assert_called_once()

        # Verify that query_ollama was not called
        mock_query_ollama.assert_not_called()

    @patch('crewai.utilities.ollama.monkey_patch.query_ollama')
    @patch('litellm.completion')
    def test_custom_completion_ollama_model_non_streaming(self, mock_original_completion, mock_query_ollama):
        """Test the custom completion function with an Ollama model in non-streaming mode."""
        # Set up the mock
        mock_query_ollama.return_value = "This is a test response."

        # Apply the monkey patch
        apply_monkey_patch()

        # Import litellm to get the patched completion function
        import litellm

        # Call the patched completion function with an Ollama model
        result = litellm.completion(
            model="ollama/llama3",
            messages=[{"role": "user", "content": "Hello"}],
            temperature=0.5
        )

        # Verify that the original completion function was not called
        mock_original_completion.assert_not_called()

        # Verify that query_ollama was called
        mock_query_ollama.assert_called_once()

        # Verify the result structure
        self.assertEqual(result.choices[0].message.content, "This is a test response.")
        self.assertEqual(result.choices[0].finish_reason, "stop")
        self.assertEqual(result.model, "ollama/llama3")
        self.assertIsNotNone(result.usage)

    @patch('crewai.utilities.ollama.monkey_patch.query_ollama')
    @patch('litellm.completion')
    def test_custom_completion_ollama_model_streaming(self, mock_original_completion, mock_query_ollama):
        """Test the custom completion function with an Ollama model in streaming mode."""
        # Set up the mock to return a generator
        mock_query_ollama.return_value = (chunk for chunk in ["This", " is", " a", " test", " response."])

        # Apply the monkey patch
        apply_monkey_patch()

        # Import litellm to get the patched completion function
        import litellm

        # Call the patched completion function with an Ollama model in streaming mode
        result = litellm.completion(
            model="ollama/llama3",
            messages=[{"role": "user", "content": "Hello"}],
            temperature=0.5,
            stream=True
        )

        # Verify that the original completion function was not called
        mock_original_completion.assert_not_called()

        # Verify that query_ollama was called
        mock_query_ollama.assert_called_once()

        # Verify the result is a generator
        self.assertTrue(hasattr(result, '__next__'))

        # Consume the generator and verify the structure of each chunk
        chunks = list(result)

        # Verify we have the expected number of chunks (5 content chunks + 1 final chunk)
        self.assertEqual(len(chunks), 6)

        # Check the content of the first 5 chunks
        for i, expected_content in enumerate(["This", " is", " a", " test", " response."]):
            self.assertEqual(chunks[i].choices[0].delta.content, expected_content)
            self.assertEqual(chunks[i].choices[0].delta.role, "assistant")
            self.assertIsNone(chunks[i].choices[0].finish_reason)

        # Check the final chunk
        self.assertEqual(chunks[5].choices[0].delta.content, "")
        self.assertEqual(chunks[5].choices[0].finish_reason, "stop")
        self.assertIsNotNone(chunks[5].usage)