mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-30 10:38:14 +00:00
Compare commits
5 Commits
0.159.0
...
devin/1755
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ac7390b287 | ||
|
|
6f139dff06 | ||
|
|
04a03d332f | ||
|
|
992e093610 | ||
|
|
07f8e73958 |
@@ -539,16 +539,71 @@ crew = Crew(
|
||||
)
|
||||
```
|
||||
|
||||
### Mem0 Provider
|
||||
|
||||
Short-Term Memory and Entity Memory both supports a tight integration with both Mem0 OSS and Mem0 Client as a provider. Here is how you can use Mem0 as a provider.
|
||||
|
||||
```python
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.memory.entity_entity_memory import EntityMemory
|
||||
|
||||
mem0_oss_embedder_config = {
|
||||
"provider": "mem0",
|
||||
"config": {
|
||||
"user_id": "john",
|
||||
"local_mem0_config": {
|
||||
"vector_store": {"provider": "qdrant","config": {"host": "localhost", "port": 6333}},
|
||||
"llm": {"provider": "openai","config": {"api_key": "your-api-key", "model": "gpt-4"}},
|
||||
"embedder": {"provider": "openai","config": {"api_key": "your-api-key", "model": "text-embedding-3-small"}}
|
||||
},
|
||||
"infer": True # Optional defaults to True
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
mem0_client_embedder_config = {
|
||||
"provider": "mem0",
|
||||
"config": {
|
||||
"user_id": "john",
|
||||
"org_id": "my_org_id", # Optional
|
||||
"project_id": "my_project_id", # Optional
|
||||
"api_key": "custom-api-key" # Optional - overrides env var
|
||||
"run_id": "my_run_id", # Optional - for short-term memory
|
||||
"includes": "include1", # Optional
|
||||
"excludes": "exclude1", # Optional
|
||||
"infer": True # Optional defaults to True
|
||||
"custom_categories": new_categories # Optional - custom categories for user memory
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
short_term_memory_mem0_oss = ShortTermMemory(embedder_config=mem0_oss_embedder_config) # Short Term Memory with Mem0 OSS
|
||||
short_term_memory_mem0_client = ShortTermMemory(embedder_config=mem0_client_embedder_config) # Short Term Memory with Mem0 Client
|
||||
entity_memory_mem0_oss = EntityMemory(embedder_config=mem0_oss_embedder_config) # Entity Memory with Mem0 OSS
|
||||
entity_memory_mem0_client = EntityMemory(embedder_config=mem0_client_embedder_config) # Short Term Memory with Mem0 Client
|
||||
|
||||
crew = Crew(
|
||||
memory=True,
|
||||
short_term_memory=short_term_memory_mem0_oss, # or short_term_memory_mem0_client
|
||||
entity_memory=entity_memory_mem0_oss # or entity_memory_mem0_client
|
||||
)
|
||||
```
|
||||
|
||||
### Choosing the Right Embedding Provider
|
||||
|
||||
| Provider | Best For | Pros | Cons |
|
||||
|:---------|:----------|:------|:------|
|
||||
| **OpenAI** | General use, reliability | High quality, well-tested | Cost, requires API key |
|
||||
| **Ollama** | Privacy, cost savings | Free, local, private | Requires local setup |
|
||||
| **Google AI** | Google ecosystem | Good performance | Requires Google account |
|
||||
| **Azure OpenAI** | Enterprise, compliance | Enterprise features | Complex setup |
|
||||
| **Cohere** | Multilingual content | Great language support | Specialized use case |
|
||||
| **VoyageAI** | Retrieval tasks | Optimized for search | Newer provider |
|
||||
When selecting an embedding provider, consider factors like performance, privacy, cost, and integration needs.
|
||||
Below is a comparison to help you decide:
|
||||
|
||||
| Provider | Best For | Pros | Cons |
|
||||
| -------------- | ------------------------------ | --------------------------------- | ------------------------- |
|
||||
| **OpenAI** | General use, high reliability | High quality, widely tested | Paid service, API key required |
|
||||
| **Ollama** | Privacy-focused, cost savings | Free, runs locally, fully private | Requires local installation/setup |
|
||||
| **Google AI** | Integration in Google ecosystem| Strong performance, good support | Google account required |
|
||||
| **Azure OpenAI** | Enterprise & compliance needs| Enterprise-grade features, security | More complex setup process |
|
||||
| **Cohere** | Multilingual content handling | Excellent language support | More niche use cases |
|
||||
| **VoyageAI** | Information retrieval & search | Optimized for retrieval tasks | Relatively new provider |
|
||||
| **Mem0** | Per-user personalization | Search-optimized embeddings | Paid service, API key required |
|
||||
|
||||
|
||||
### Environment Variable Configuration
|
||||
|
||||
|
||||
@@ -346,7 +346,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
try:
|
||||
self.knowledge_search_query = self._get_knowledge_search_query(
|
||||
task_prompt
|
||||
task_prompt, context
|
||||
)
|
||||
if self.knowledge_search_query:
|
||||
# Quering agent specific knowledge
|
||||
@@ -722,8 +722,8 @@ class Agent(BaseAgent):
|
||||
def set_fingerprint(self, fingerprint: Fingerprint):
|
||||
self.security_config.fingerprint = fingerprint
|
||||
|
||||
def _get_knowledge_search_query(self, task_prompt: str) -> str | None:
|
||||
"""Generate a search query for the knowledge base based on the task description."""
|
||||
def _get_knowledge_search_query(self, task_prompt: str, context: Optional[str] = None) -> str | None:
|
||||
"""Generate a search query for the knowledge base based on the task description and context."""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=KnowledgeQueryStartedEvent(
|
||||
@@ -731,9 +731,16 @@ class Agent(BaseAgent):
|
||||
agent=self,
|
||||
),
|
||||
)
|
||||
query = self.i18n.slice("knowledge_search_query").format(
|
||||
task_prompt=task_prompt
|
||||
)
|
||||
|
||||
if context:
|
||||
query = self.i18n.slice("knowledge_search_query_with_context").format(
|
||||
task_prompt=task_prompt, context=context
|
||||
)
|
||||
else:
|
||||
query = self.i18n.slice("knowledge_search_query").format(
|
||||
task_prompt=task_prompt
|
||||
)
|
||||
|
||||
rewriter_prompt = self.i18n.slice("knowledge_search_query_system_prompt")
|
||||
if not isinstance(self.llm, BaseLLM):
|
||||
self._logger.log(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Dict, List
|
||||
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
@@ -21,6 +21,7 @@ class CrewAgentExecutorMixin:
|
||||
task: "Task"
|
||||
iterations: int
|
||||
max_iter: int
|
||||
messages: List[Dict[str, str]]
|
||||
_i18n: I18N
|
||||
_printer: Printer = Printer()
|
||||
|
||||
@@ -62,6 +63,7 @@ class CrewAgentExecutorMixin:
|
||||
value=output.text,
|
||||
metadata={
|
||||
"description": self.task.description,
|
||||
"messages": self.messages,
|
||||
},
|
||||
agent=self.agent.role,
|
||||
)
|
||||
@@ -127,7 +129,6 @@ class CrewAgentExecutorMixin:
|
||||
def _ask_human_input(self, final_answer: str) -> str:
|
||||
"""Prompt human input with mode-appropriate messaging."""
|
||||
event_listener.formatter.pause_live_updates()
|
||||
|
||||
try:
|
||||
self._printer.print(
|
||||
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
from .utils import TokenManager
|
||||
|
||||
|
||||
class AuthError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def get_auth_token() -> str:
|
||||
"""Get the authentication token."""
|
||||
access_token = TokenManager().get_token()
|
||||
if not access_token:
|
||||
raise Exception("No token found, make sure you are logged in")
|
||||
raise AuthError("No token found, make sure you are logged in")
|
||||
return access_token
|
||||
|
||||
@@ -18,6 +18,7 @@ class PlusAPI:
|
||||
CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
|
||||
AGENTS_RESOURCE = "/crewai_plus/api/v1/agents"
|
||||
TRACING_RESOURCE = "/crewai_plus/api/v1/tracing"
|
||||
EPHEMERAL_TRACING_RESOURCE = "/crewai_plus/api/v1/tracing/ephemeral"
|
||||
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
@@ -124,6 +125,11 @@ class PlusAPI:
|
||||
"POST", f"{self.TRACING_RESOURCE}/batches", json=payload
|
||||
)
|
||||
|
||||
def initialize_ephemeral_trace_batch(self, payload) -> requests.Response:
|
||||
return self._make_request(
|
||||
"POST", f"{self.EPHEMERAL_TRACING_RESOURCE}/batches", json=payload
|
||||
)
|
||||
|
||||
def send_trace_events(self, trace_batch_id: str, payload) -> requests.Response:
|
||||
return self._make_request(
|
||||
"POST",
|
||||
@@ -131,9 +137,27 @@ class PlusAPI:
|
||||
json=payload,
|
||||
)
|
||||
|
||||
def send_ephemeral_trace_events(
|
||||
self, trace_batch_id: str, payload
|
||||
) -> requests.Response:
|
||||
return self._make_request(
|
||||
"POST",
|
||||
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/events",
|
||||
json=payload,
|
||||
)
|
||||
|
||||
def finalize_trace_batch(self, trace_batch_id: str, payload) -> requests.Response:
|
||||
return self._make_request(
|
||||
"PATCH",
|
||||
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
|
||||
json=payload,
|
||||
)
|
||||
|
||||
def finalize_ephemeral_trace_batch(
|
||||
self, trace_batch_id: str, payload
|
||||
) -> requests.Response:
|
||||
return self._make_request(
|
||||
"PATCH",
|
||||
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
|
||||
json=payload,
|
||||
)
|
||||
|
||||
@@ -77,7 +77,10 @@ from crewai.utilities.events.listeners.tracing.trace_listener import (
|
||||
)
|
||||
|
||||
|
||||
from crewai.utilities.events.listeners.tracing.utils import is_tracing_enabled
|
||||
from crewai.utilities.events.listeners.tracing.utils import (
|
||||
is_tracing_enabled,
|
||||
on_first_execution_tracing_confirmation,
|
||||
)
|
||||
from crewai.utilities.formatter import (
|
||||
aggregate_raw_outputs_from_task_outputs,
|
||||
aggregate_raw_outputs_from_tasks,
|
||||
@@ -283,8 +286,11 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
self._cache_handler = CacheHandler()
|
||||
event_listener = EventListener()
|
||||
if on_first_execution_tracing_confirmation():
|
||||
self.tracing = True
|
||||
|
||||
if is_tracing_enabled() or self.tracing:
|
||||
trace_listener = TraceCollectionListener(tracing=self.tracing)
|
||||
trace_listener = TraceCollectionListener()
|
||||
trace_listener.setup_listeners(crewai_event_bus)
|
||||
event_listener.verbose = self.verbose
|
||||
event_listener.formatter.verbose = self.verbose
|
||||
|
||||
@@ -38,7 +38,10 @@ from crewai.utilities.events.flow_events import (
|
||||
from crewai.utilities.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
from crewai.utilities.events.listeners.tracing.utils import is_tracing_enabled
|
||||
from crewai.utilities.events.listeners.tracing.utils import (
|
||||
is_tracing_enabled,
|
||||
on_first_execution_tracing_confirmation,
|
||||
)
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -476,8 +479,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
# Initialize state with initial values
|
||||
self._state = self._create_initial_state()
|
||||
self.tracing = tracing
|
||||
if is_tracing_enabled() or tracing:
|
||||
trace_listener = TraceCollectionListener(tracing=tracing)
|
||||
if (
|
||||
on_first_execution_tracing_confirmation()
|
||||
or is_tracing_enabled()
|
||||
or self.tracing
|
||||
):
|
||||
trace_listener = TraceCollectionListener()
|
||||
trace_listener.setup_listeners(crewai_event_bus)
|
||||
# Apply any additional kwargs
|
||||
if kwargs:
|
||||
|
||||
@@ -29,6 +29,7 @@
|
||||
"lite_agent_system_prompt_without_tools": "You are {role}. {backstory}\nYour personal goal is: {goal}\n\nTo give my best complete final answer to the task respond using the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!",
|
||||
"lite_agent_response_format": "\nIMPORTANT: Your final answer MUST contain all the information requested in the following format: {response_format}\n\nIMPORTANT: Ensure the final output does not include any code block markers like ```json or ```python.",
|
||||
"knowledge_search_query": "The original query is: {task_prompt}.",
|
||||
"knowledge_search_query_with_context": "The original query is: {task_prompt}.\n\nContext from previous tasks:\n{context}",
|
||||
"knowledge_search_query_system_prompt": "Your goal is to rewrite the user query so that it is optimized for retrieval from a vector database. Consider how the query will be used to find relevant documents, and aim to make it more specific and context-aware. \n\n Do not include any other text than the rewritten query, especially any preamble or postamble and only add expected output format if its relevant to the rewritten query. \n\n Focus on the key words of the intended task and to retrieve the most relevant information. \n\n There will be some extra context provided that might need to be removed such as expected_output formats structured_outputs and other instructions."
|
||||
},
|
||||
"errors": {
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Dict, List, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from crewai.utilities.constants import CREWAI_BASE_URL
|
||||
from crewai.cli.authentication.token import get_auth_token
|
||||
from crewai.cli.authentication.token import AuthError, get_auth_token
|
||||
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
@@ -41,14 +41,21 @@ class TraceBatchManager:
|
||||
"""Single responsibility: Manage batches and event buffering"""
|
||||
|
||||
def __init__(self):
|
||||
self.plus_api = PlusAPI(api_key=get_auth_token())
|
||||
try:
|
||||
self.plus_api = PlusAPI(api_key=get_auth_token())
|
||||
except AuthError:
|
||||
self.plus_api = PlusAPI(api_key="")
|
||||
|
||||
self.trace_batch_id: Optional[str] = None # Backend ID
|
||||
self.current_batch: Optional[TraceBatch] = None
|
||||
self.event_buffer: List[TraceEvent] = []
|
||||
self.execution_start_times: Dict[str, datetime] = {}
|
||||
|
||||
def initialize_batch(
|
||||
self, user_context: Dict[str, str], execution_metadata: Dict[str, Any]
|
||||
self,
|
||||
user_context: Dict[str, str],
|
||||
execution_metadata: Dict[str, Any],
|
||||
use_ephemeral: bool = False,
|
||||
) -> TraceBatch:
|
||||
"""Initialize a new trace batch"""
|
||||
self.current_batch = TraceBatch(
|
||||
@@ -57,13 +64,15 @@ class TraceBatchManager:
|
||||
self.event_buffer.clear()
|
||||
|
||||
self.record_start_time("execution")
|
||||
|
||||
self._initialize_backend_batch(user_context, execution_metadata)
|
||||
self._initialize_backend_batch(user_context, execution_metadata, use_ephemeral)
|
||||
|
||||
return self.current_batch
|
||||
|
||||
def _initialize_backend_batch(
|
||||
self, user_context: Dict[str, str], execution_metadata: Dict[str, Any]
|
||||
self,
|
||||
user_context: Dict[str, str],
|
||||
execution_metadata: Dict[str, Any],
|
||||
use_ephemeral: bool = False,
|
||||
):
|
||||
"""Send batch initialization to backend"""
|
||||
|
||||
@@ -74,6 +83,7 @@ class TraceBatchManager:
|
||||
payload = {
|
||||
"trace_id": self.current_batch.batch_id,
|
||||
"execution_type": execution_metadata.get("execution_type", "crew"),
|
||||
"user_identifier": execution_metadata.get("user_context", None),
|
||||
"execution_context": {
|
||||
"crew_fingerprint": execution_metadata.get("crew_fingerprint"),
|
||||
"crew_name": execution_metadata.get("crew_name", None),
|
||||
@@ -91,12 +101,22 @@ class TraceBatchManager:
|
||||
"execution_started_at": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
}
|
||||
if use_ephemeral:
|
||||
payload["ephemeral_trace_id"] = self.current_batch.batch_id
|
||||
|
||||
response = self.plus_api.initialize_trace_batch(payload)
|
||||
response = (
|
||||
self.plus_api.initialize_ephemeral_trace_batch(payload)
|
||||
if use_ephemeral
|
||||
else self.plus_api.initialize_trace_batch(payload)
|
||||
)
|
||||
|
||||
if response.status_code == 201 or response.status_code == 200:
|
||||
response_data = response.json()
|
||||
self.trace_batch_id = response_data["trace_id"]
|
||||
self.trace_batch_id = (
|
||||
response_data["trace_id"]
|
||||
if not use_ephemeral
|
||||
else response_data["ephemeral_trace_id"]
|
||||
)
|
||||
console = Console()
|
||||
panel = Panel(
|
||||
f"✅ Trace batch initialized with session ID: {self.trace_batch_id}",
|
||||
@@ -116,7 +136,7 @@ class TraceBatchManager:
|
||||
"""Add event to buffer"""
|
||||
self.event_buffer.append(trace_event)
|
||||
|
||||
def _send_events_to_backend(self):
|
||||
def _send_events_to_backend(self, ephemeral: bool = True):
|
||||
"""Send buffered events to backend"""
|
||||
if not self.plus_api or not self.trace_batch_id or not self.event_buffer:
|
||||
return
|
||||
@@ -134,7 +154,11 @@ class TraceBatchManager:
|
||||
if not self.trace_batch_id:
|
||||
raise Exception("❌ Trace batch ID not found")
|
||||
|
||||
response = self.plus_api.send_trace_events(self.trace_batch_id, payload)
|
||||
response = (
|
||||
self.plus_api.send_ephemeral_trace_events(self.trace_batch_id, payload)
|
||||
if ephemeral
|
||||
else self.plus_api.send_trace_events(self.trace_batch_id, payload)
|
||||
)
|
||||
|
||||
if response.status_code == 200 or response.status_code == 201:
|
||||
self.event_buffer.clear()
|
||||
@@ -146,15 +170,15 @@ class TraceBatchManager:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error sending events to backend: {str(e)}")
|
||||
|
||||
def finalize_batch(self) -> Optional[TraceBatch]:
|
||||
def finalize_batch(self, ephemeral: bool = True) -> Optional[TraceBatch]:
|
||||
"""Finalize batch and return it for sending"""
|
||||
if not self.current_batch:
|
||||
return None
|
||||
|
||||
if self.event_buffer:
|
||||
self._send_events_to_backend()
|
||||
self._send_events_to_backend(ephemeral)
|
||||
|
||||
self._finalize_backend_batch()
|
||||
self._finalize_backend_batch(ephemeral)
|
||||
|
||||
self.current_batch.events = self.event_buffer.copy()
|
||||
|
||||
@@ -168,7 +192,7 @@ class TraceBatchManager:
|
||||
|
||||
return finalized_batch
|
||||
|
||||
def _finalize_backend_batch(self):
|
||||
def _finalize_backend_batch(self, ephemeral: bool = True):
|
||||
"""Send batch finalization to backend"""
|
||||
if not self.plus_api or not self.trace_batch_id:
|
||||
return
|
||||
@@ -182,12 +206,24 @@ class TraceBatchManager:
|
||||
"final_event_count": total_events,
|
||||
}
|
||||
|
||||
response = self.plus_api.finalize_trace_batch(self.trace_batch_id, payload)
|
||||
response = (
|
||||
self.plus_api.finalize_ephemeral_trace_batch(
|
||||
self.trace_batch_id, payload
|
||||
)
|
||||
if ephemeral
|
||||
else self.plus_api.finalize_trace_batch(self.trace_batch_id, payload)
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
access_code = response.json().get("access_code", None)
|
||||
console = Console()
|
||||
return_link = (
|
||||
f"{CREWAI_BASE_URL}/crewai_plus/trace_batches/{self.trace_batch_id}"
|
||||
if not ephemeral and access_code
|
||||
else f"{CREWAI_BASE_URL}/crewai_plus/ephemeral_trace_batches/{self.trace_batch_id}?access_code={access_code}"
|
||||
)
|
||||
panel = Panel(
|
||||
f"✅ Trace batch finalized with session ID: {self.trace_batch_id}. View here: {CREWAI_BASE_URL}/crewai_plus/trace_batches/{self.trace_batch_id}",
|
||||
f"✅ Trace batch finalized with session ID: {self.trace_batch_id}. View here: {return_link} {f', Access Code: {access_code}' if access_code else ''}",
|
||||
title="Trace Batch Finalization",
|
||||
border_style="green",
|
||||
)
|
||||
|
||||
@@ -13,7 +13,6 @@ from crewai.utilities.events.agent_events import (
|
||||
AgentExecutionErrorEvent,
|
||||
)
|
||||
from crewai.utilities.events.listeners.tracing.types import TraceEvent
|
||||
from crewai.utilities.events.listeners.tracing.utils import is_tracing_enabled
|
||||
from crewai.utilities.events.reasoning_events import (
|
||||
AgentReasoningStartedEvent,
|
||||
AgentReasoningCompletedEvent,
|
||||
@@ -67,7 +66,7 @@ from crewai.utilities.events.memory_events import (
|
||||
MemorySaveFailedEvent,
|
||||
)
|
||||
|
||||
from crewai.cli.authentication.token import get_auth_token
|
||||
from crewai.cli.authentication.token import AuthError, get_auth_token
|
||||
from crewai.cli.version import get_crewai_version
|
||||
|
||||
|
||||
@@ -76,13 +75,12 @@ class TraceCollectionListener(BaseEventListener):
|
||||
Trace collection listener that orchestrates trace collection
|
||||
"""
|
||||
|
||||
trace_enabled: Optional[bool] = False
|
||||
complex_events = ["task_started", "llm_call_started", "llm_call_completed"]
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls, batch_manager=None, tracing: Optional[bool] = False):
|
||||
def __new__(cls, batch_manager=None):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
@@ -90,25 +88,22 @@ class TraceCollectionListener(BaseEventListener):
|
||||
def __init__(
|
||||
self,
|
||||
batch_manager: Optional[TraceBatchManager] = None,
|
||||
tracing: Optional[bool] = False,
|
||||
):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
super().__init__()
|
||||
self.batch_manager = batch_manager or TraceBatchManager()
|
||||
self.tracing = tracing or False
|
||||
self.trace_enabled = self._check_trace_enabled()
|
||||
self._initialized = True
|
||||
|
||||
def _check_trace_enabled(self) -> bool:
|
||||
def _check_authenticated(self) -> bool:
|
||||
"""Check if tracing should be enabled"""
|
||||
auth_token = get_auth_token()
|
||||
if not auth_token:
|
||||
try:
|
||||
res = bool(get_auth_token())
|
||||
return res
|
||||
except AuthError:
|
||||
return False
|
||||
|
||||
return is_tracing_enabled() or self.tracing
|
||||
|
||||
def _get_user_context(self) -> Dict[str, str]:
|
||||
"""Extract user context for tracing"""
|
||||
return {
|
||||
@@ -120,8 +115,6 @@ class TraceCollectionListener(BaseEventListener):
|
||||
|
||||
def setup_listeners(self, crewai_event_bus):
|
||||
"""Setup event listeners - delegates to specific handlers"""
|
||||
if not self.trace_enabled:
|
||||
return
|
||||
|
||||
self._register_flow_event_handlers(crewai_event_bus)
|
||||
self._register_context_event_handlers(crewai_event_bus)
|
||||
@@ -167,13 +160,13 @@ class TraceCollectionListener(BaseEventListener):
|
||||
@event_bus.on(CrewKickoffStartedEvent)
|
||||
def on_crew_started(source, event):
|
||||
if not self.batch_manager.is_batch_initialized():
|
||||
self._initialize_batch(source, event)
|
||||
self._initialize_crew_batch(source, event)
|
||||
self._handle_trace_event("crew_kickoff_started", source, event)
|
||||
|
||||
@event_bus.on(CrewKickoffCompletedEvent)
|
||||
def on_crew_completed(source, event):
|
||||
self._handle_trace_event("crew_kickoff_completed", source, event)
|
||||
self.batch_manager.finalize_batch()
|
||||
self.batch_manager.finalize_batch(ephemeral=True)
|
||||
|
||||
@event_bus.on(CrewKickoffFailedEvent)
|
||||
def on_crew_failed(source, event):
|
||||
@@ -287,7 +280,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
def on_agent_reasoning_failed(source, event):
|
||||
self._handle_action_event("agent_reasoning_failed", source, event)
|
||||
|
||||
def _initialize_batch(self, source: Any, event: Any):
|
||||
def _initialize_crew_batch(self, source: Any, event: Any):
|
||||
"""Initialize trace batch"""
|
||||
user_context = self._get_user_context()
|
||||
execution_metadata = {
|
||||
@@ -296,7 +289,7 @@ class TraceCollectionListener(BaseEventListener):
|
||||
"crewai_version": get_crewai_version(),
|
||||
}
|
||||
|
||||
self.batch_manager.initialize_batch(user_context, execution_metadata)
|
||||
self._initialize_batch(user_context, execution_metadata)
|
||||
|
||||
def _initialize_flow_batch(self, source: Any, event: Any):
|
||||
"""Initialize trace batch for Flow execution"""
|
||||
@@ -308,7 +301,20 @@ class TraceCollectionListener(BaseEventListener):
|
||||
"execution_type": "flow",
|
||||
}
|
||||
|
||||
self.batch_manager.initialize_batch(user_context, execution_metadata)
|
||||
self._initialize_batch(user_context, execution_metadata)
|
||||
|
||||
def _initialize_batch(
|
||||
self, user_context: Dict[str, str], execution_metadata: Dict[str, Any]
|
||||
):
|
||||
"""Initialize trace batch if ephemeral"""
|
||||
if not self._check_authenticated():
|
||||
self.batch_manager.initialize_batch(
|
||||
user_context, execution_metadata, use_ephemeral=True
|
||||
)
|
||||
else:
|
||||
self.batch_manager.initialize_batch(
|
||||
user_context, execution_metadata, use_ephemeral=False
|
||||
)
|
||||
|
||||
def _handle_trace_event(self, event_type: str, source: Any, event: Any):
|
||||
"""Generic handler for context end events"""
|
||||
|
||||
@@ -1,5 +1,153 @@
|
||||
import os
|
||||
import platform
|
||||
import uuid
|
||||
import hashlib
|
||||
import subprocess
|
||||
import getpass
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
import re
|
||||
import json
|
||||
|
||||
import click
|
||||
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
def is_tracing_enabled() -> bool:
|
||||
return os.getenv("CREWAI_TRACING_ENABLED", "false").lower() == "true"
|
||||
|
||||
|
||||
def on_first_execution_tracing_confirmation() -> bool:
|
||||
if _is_test_environment():
|
||||
return False
|
||||
|
||||
if is_first_execution():
|
||||
mark_first_execution_done()
|
||||
return click.confirm(
|
||||
"This is the first execution of CrewAI. Do you want to enable tracing?",
|
||||
default=True,
|
||||
show_default=True,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def _is_test_environment() -> bool:
|
||||
"""Detect if we're running in a test environment."""
|
||||
return os.environ.get("CREWAI_TESTING", "").lower() == "true"
|
||||
|
||||
|
||||
def _get_machine_id() -> str:
|
||||
"""Stable, privacy-preserving machine fingerprint (cross-platform)."""
|
||||
parts = []
|
||||
|
||||
try:
|
||||
mac = ":".join(
|
||||
["{:02x}".format((uuid.getnode() >> b) & 0xFF) for b in range(0, 12, 2)][
|
||||
::-1
|
||||
]
|
||||
)
|
||||
parts.append(mac)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
sysname = platform.system()
|
||||
parts.append(sysname)
|
||||
|
||||
try:
|
||||
if sysname == "Darwin":
|
||||
res = subprocess.run(
|
||||
["system_profiler", "SPHardwareDataType"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=2,
|
||||
)
|
||||
m = re.search(r"Hardware UUID:\s*([A-Fa-f0-9\-]+)", res.stdout)
|
||||
if m:
|
||||
parts.append(m.group(1))
|
||||
elif sysname == "Linux":
|
||||
try:
|
||||
parts.append(Path("/etc/machine-id").read_text().strip())
|
||||
except Exception:
|
||||
parts.append(Path("/sys/class/dmi/id/product_uuid").read_text().strip())
|
||||
elif sysname == "Windows":
|
||||
res = subprocess.run(
|
||||
["wmic", "csproduct", "get", "UUID"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=2,
|
||||
)
|
||||
lines = [line.strip() for line in res.stdout.splitlines() if line.strip()]
|
||||
if len(lines) >= 2:
|
||||
parts.append(lines[1])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return hashlib.sha256("".join(parts).encode()).hexdigest()
|
||||
|
||||
|
||||
def _user_data_file() -> Path:
|
||||
base = Path(db_storage_path())
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
return base / ".crewai_user.json"
|
||||
|
||||
|
||||
def _load_user_data() -> dict:
|
||||
p = _user_data_file()
|
||||
if p.exists():
|
||||
try:
|
||||
return json.loads(p.read_text())
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
def _save_user_data(data: dict) -> None:
|
||||
try:
|
||||
p = _user_data_file()
|
||||
p.write_text(json.dumps(data, indent=2))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def get_user_id() -> str:
|
||||
"""Stable, anonymized user identifier with caching."""
|
||||
data = _load_user_data()
|
||||
|
||||
if "user_id" in data:
|
||||
return data["user_id"]
|
||||
|
||||
try:
|
||||
username = getpass.getuser()
|
||||
except Exception:
|
||||
username = "unknown"
|
||||
|
||||
seed = f"{username}|{_get_machine_id()}"
|
||||
uid = hashlib.sha256(seed.encode()).hexdigest()
|
||||
|
||||
data["user_id"] = uid
|
||||
_save_user_data(data)
|
||||
return uid
|
||||
|
||||
|
||||
def is_first_execution() -> bool:
|
||||
"""True if this is the first execution for this user."""
|
||||
data = _load_user_data()
|
||||
return not data.get("first_execution_done", False)
|
||||
|
||||
|
||||
def mark_first_execution_done() -> None:
|
||||
"""Mark that the first execution has been completed."""
|
||||
data = _load_user_data()
|
||||
if data.get("first_execution_done", False):
|
||||
return
|
||||
|
||||
data.update(
|
||||
{
|
||||
"first_execution_done": True,
|
||||
"first_execution_at": datetime.now().timestamp(),
|
||||
"user_id": get_user_id(),
|
||||
"machine_id": _get_machine_id(),
|
||||
}
|
||||
)
|
||||
_save_user_data(data)
|
||||
|
||||
@@ -1756,6 +1756,35 @@ def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_agent_knowledge_search_with_context_parameter():
|
||||
"""Test that agent knowledge search accepts context parameter."""
|
||||
content = "Brandon's favorite color is red and he likes Mexican food."
|
||||
string_source = StringKnowledgeSource(content=content)
|
||||
|
||||
agent = Agent(
|
||||
role="Information Agent",
|
||||
goal="Provide information based on knowledge sources",
|
||||
backstory="You have access to specific knowledge sources.",
|
||||
llm=LLM(model="gpt-4o-mini"),
|
||||
knowledge_sources=[string_source],
|
||||
)
|
||||
|
||||
task_prompt = "What is Brandon's favorite color?"
|
||||
context = "Previous conversation mentioned Brandon's preferences."
|
||||
|
||||
with patch.object(agent.llm, 'call') as mock_call:
|
||||
mock_call.return_value = "Brandon likes red"
|
||||
result = agent._get_knowledge_search_query(task_prompt, context)
|
||||
|
||||
assert result == "Brandon likes red"
|
||||
mock_call.assert_called_once()
|
||||
|
||||
call_args = mock_call.call_args[0][0]
|
||||
user_message = call_args[1]['content']
|
||||
assert context in user_message
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_agent_with_knowledge_sources_with_query_limit_and_score_threshold_default():
|
||||
content = "Brandon's favorite color is red and he likes Mexican food."
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -34,11 +34,11 @@ def setup_test_environment():
|
||||
f"Test storage directory {storage_dir} is not writable: {e}"
|
||||
)
|
||||
|
||||
# Set environment variable to point to the test storage directory
|
||||
os.environ["CREWAI_STORAGE_DIR"] = str(storage_dir)
|
||||
|
||||
os.environ["CREWAI_TESTING"] = "true"
|
||||
yield
|
||||
|
||||
os.environ.pop("CREWAI_TESTING", None)
|
||||
# Cleanup is handled automatically when tempfile context exits
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Test Agent creation and execution basic functionality."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from concurrent.futures import Future
|
||||
@@ -51,7 +50,7 @@ from crewai.utilities.events.memory_events import (
|
||||
MemoryRetrievalStartedEvent,
|
||||
MemoryRetrievalCompletedEvent,
|
||||
)
|
||||
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
|
||||
@pytest.fixture
|
||||
def ceo():
|
||||
@@ -312,7 +311,6 @@ def test_crew_creation(researcher, writer):
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_sync_task_execution(researcher, writer):
|
||||
from unittest.mock import patch
|
||||
|
||||
tasks = [
|
||||
Task(
|
||||
@@ -961,7 +959,6 @@ def test_cache_hitting_between_agents(researcher, writer, ceo):
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_api_calls_throttling(capsys):
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai.tools import tool
|
||||
|
||||
@@ -1396,7 +1393,6 @@ def test_kickoff_for_each_invalid_input():
|
||||
|
||||
def test_kickoff_for_each_error_handling():
|
||||
"""Tests error handling in kickoff_for_each when kickoff raises an error."""
|
||||
from unittest.mock import patch
|
||||
|
||||
inputs = [
|
||||
{"topic": "dog"},
|
||||
@@ -1433,7 +1429,6 @@ def test_kickoff_for_each_error_handling():
|
||||
@pytest.mark.asyncio
|
||||
async def test_kickoff_async_basic_functionality_and_output():
|
||||
"""Tests the basic functionality and output of kickoff_async."""
|
||||
from unittest.mock import patch
|
||||
|
||||
inputs = {"topic": "dog"}
|
||||
|
||||
@@ -1540,7 +1535,6 @@ async def test_async_kickoff_for_each_async_empty_input():
|
||||
|
||||
|
||||
def test_set_agents_step_callback():
|
||||
from unittest.mock import patch
|
||||
|
||||
researcher_agent = Agent(
|
||||
role="Researcher",
|
||||
@@ -1570,7 +1564,6 @@ def test_set_agents_step_callback():
|
||||
|
||||
|
||||
def test_dont_set_agents_step_callback_if_already_set():
|
||||
from unittest.mock import patch
|
||||
|
||||
def agent_callback(_):
|
||||
pass
|
||||
@@ -2035,7 +2028,6 @@ def test_crew_inputs_interpolate_both_agents_and_tasks():
|
||||
|
||||
|
||||
def test_crew_inputs_interpolate_both_agents_and_tasks_diff():
|
||||
from unittest.mock import patch
|
||||
|
||||
agent = Agent(
|
||||
role="{topic} Researcher",
|
||||
@@ -2068,7 +2060,6 @@ def test_crew_inputs_interpolate_both_agents_and_tasks_diff():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_does_not_interpolate_without_inputs():
|
||||
from unittest.mock import patch
|
||||
|
||||
agent = Agent(
|
||||
role="{topic} Researcher",
|
||||
@@ -2203,7 +2194,6 @@ def test_task_same_callback_both_on_task_and_crew():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_tools_with_custom_caching():
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai.tools import tool
|
||||
|
||||
@@ -2484,7 +2474,6 @@ def test_multiple_conditional_tasks(researcher, writer):
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_using_contextual_memory():
|
||||
from unittest.mock import patch
|
||||
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
@@ -2583,7 +2572,6 @@ def test_memory_events_are_emitted():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_using_contextual_memory_with_long_term_memory():
|
||||
from unittest.mock import patch
|
||||
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
@@ -2614,7 +2602,6 @@ def test_using_contextual_memory_with_long_term_memory():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_warning_long_term_memory_without_entity_memory():
|
||||
from unittest.mock import patch
|
||||
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
@@ -2651,7 +2638,6 @@ def test_warning_long_term_memory_without_entity_memory():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_long_term_memory_with_memory_flag():
|
||||
from unittest.mock import patch
|
||||
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
@@ -2686,7 +2672,6 @@ def test_long_term_memory_with_memory_flag():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_using_contextual_memory_with_short_term_memory():
|
||||
from unittest.mock import patch
|
||||
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
@@ -2717,7 +2702,6 @@ def test_using_contextual_memory_with_short_term_memory():
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_disabled_memory_using_contextual_memory():
|
||||
from unittest.mock import patch
|
||||
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
@@ -2845,7 +2829,6 @@ def test_crew_output_file_validation_failures():
|
||||
|
||||
|
||||
def test_manager_agent(researcher, writer):
|
||||
from unittest.mock import patch
|
||||
|
||||
task = Task(
|
||||
description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
|
||||
@@ -4752,3 +4735,43 @@ def test_default_crew_name(researcher, writer):
|
||||
],
|
||||
)
|
||||
assert crew.name == "crew"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_ensure_exchanged_messages_are_propagated_to_external_memory():
|
||||
external_memory = ExternalMemory(storage=MagicMock())
|
||||
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="You research about math.",
|
||||
backstory="You're an expert in research and you love to learn new things.",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
task1 = Task(
|
||||
description="Research a topic to teach a kid aged 6 about math.",
|
||||
expected_output="A topic, explanation, angle, and examples.",
|
||||
agent=math_researcher,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[math_researcher],
|
||||
tasks=[task1],
|
||||
external_memory=external_memory,
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
ExternalMemory, "save", return_value=None
|
||||
) as external_memory_save:
|
||||
crew.kickoff()
|
||||
|
||||
expected_messages = [
|
||||
{'role': 'system', 'content': "You are Researcher. You're an expert in research and you love to learn new things.\nYour personal goal is: You research about math.\nTo give my best complete final answer to the task respond using the exact following format:\n\nThought: I now can give a great answer\nFinal Answer: Your final answer must be the great and the most complete as possible, it must be outcome described.\n\nI MUST use these formats, my job depends on it!"},
|
||||
{'role': 'user', 'content': '\nCurrent Task: Research a topic to teach a kid aged 6 about math.\n\nThis is the expected criteria for your final answer: A topic, explanation, angle, and examples.\nyou MUST return the actual complete content as the final answer, not a summary.\n\n# Useful context: \nExternal memories:\n\n\nBegin! This is VERY important to you, use the tools available and give your best Final Answer, your job depends on it!\n\nThought:'},
|
||||
{'role': 'assistant', 'content': 'I now can give a great answer \nFinal Answer: \n\n**Topic: Understanding Shapes (Geometry)**\n\n**Explanation:** \nShapes are everywhere around us! They are the special forms that we can see in everyday objects. Teaching a 6-year-old about shapes is not only fun but also a way to help them think about the world around them and develop their spatial awareness. We will focus on basic shapes: circle, square, triangle, and rectangle. Understanding these shapes helps kids recognize and describe their environment.\n\n**Angle:** \nLet’s make learning about shapes an adventure! We can turn it into a treasure hunt where the child has to find objects around the house or outside that match the shapes we learn. This hands-on approach helps make the learning stick!\n\n**Examples:** \n1. **Circle:** \n - Explanation: A circle is round and has no corners. It looks like a wheel or a cookie! \n - Activity: Find objects that are circles, such as a clock, a dinner plate, or a ball. Draw a big circle on a paper and then try to draw smaller circles inside it.\n\n2. **Square:** \n - Explanation: A square has four equal sides and four corners. It looks like a box! \n - Activity: Look for squares in books, in windows, or in building blocks. Try to build a tall tower using square blocks!\n\n3. **Triangle:** \n - Explanation: A triangle has three sides and three corners. It looks like a slice of pizza or a roof! \n - Activity: Use crayons to draw a big triangle and then find things that are shaped like a triangle, like a slice of cheese or a traffic sign.\n\n4. **Rectangle:** \n - Explanation: A rectangle has four sides but only opposite sides are equal. It’s like a stretched square! \n - Activity: Search for rectangles, such as a book cover or a door. You can cut out rectangles from colored paper and create a collage!\n\nBy relating the shapes to fun activities and using real-world examples, we not only make learning more enjoyable but also help the child better remember and understand the concept of shapes in math. This foundation forms the basis of their future learning in geometry!'}
|
||||
]
|
||||
external_memory_save.assert_called_once_with(
|
||||
value=ANY,
|
||||
metadata={"description": ANY, "messages": expected_messages},
|
||||
agent=ANY,
|
||||
)
|
||||
|
||||
223
tests/test_context_aware_knowledge_search.py
Normal file
223
tests/test_context_aware_knowledge_search.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""Test context-aware knowledge search functionality."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai import Agent, Task, Crew, LLM
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_knowledge_search_with_context():
|
||||
"""Test that knowledge search includes context from previous tasks."""
|
||||
content = "The company's main product is a CRM system. The system has three modules: Sales, Marketing, and Support."
|
||||
string_source = StringKnowledgeSource(content=content)
|
||||
|
||||
researcher = Agent(
|
||||
role="Research Analyst",
|
||||
goal="Research company information",
|
||||
backstory="You are a research analyst.",
|
||||
llm=LLM(model="gpt-4o-mini"),
|
||||
knowledge_sources=[string_source],
|
||||
)
|
||||
|
||||
writer = Agent(
|
||||
role="Content Writer",
|
||||
goal="Write content based on research",
|
||||
backstory="You are a content writer.",
|
||||
llm=LLM(model="gpt-4o-mini"),
|
||||
knowledge_sources=[string_source],
|
||||
)
|
||||
|
||||
research_task = Task(
|
||||
description="Research the company's main product",
|
||||
expected_output="A summary of the company's main product",
|
||||
agent=researcher,
|
||||
)
|
||||
|
||||
writing_task = Task(
|
||||
description="Write a detailed description of the CRM modules",
|
||||
expected_output="A detailed description of each CRM module",
|
||||
agent=writer,
|
||||
context=[research_task],
|
||||
)
|
||||
|
||||
crew = Crew(agents=[researcher, writer], tasks=[research_task, writing_task])
|
||||
|
||||
with patch.object(writer, '_get_knowledge_search_query') as mock_search:
|
||||
mock_search.return_value = "mocked query"
|
||||
crew.kickoff()
|
||||
|
||||
mock_search.assert_called_once()
|
||||
call_args = mock_search.call_args
|
||||
assert len(call_args[0]) == 2
|
||||
assert call_args[0][1] is not None
|
||||
assert "CRM system" in call_args[0][1] or "product" in call_args[0][1]
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_knowledge_search_without_context():
|
||||
"""Test that knowledge search works without context (backward compatibility)."""
|
||||
content = "The company's main product is a CRM system."
|
||||
string_source = StringKnowledgeSource(content=content)
|
||||
|
||||
agent = Agent(
|
||||
role="Research Analyst",
|
||||
goal="Research company information",
|
||||
backstory="You are a research analyst.",
|
||||
llm=LLM(model="gpt-4o-mini"),
|
||||
knowledge_sources=[string_source],
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Research the company's main product",
|
||||
expected_output="A summary of the company's main product",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
|
||||
with patch.object(agent, '_get_knowledge_search_query') as mock_search:
|
||||
mock_search.return_value = "mocked query"
|
||||
crew.kickoff()
|
||||
|
||||
mock_search.assert_called_once()
|
||||
call_args = mock_search.call_args
|
||||
assert len(call_args[0]) == 2
|
||||
assert call_args[0][1] == ""
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_context_aware_knowledge_search_integration():
|
||||
"""Integration test for context-aware knowledge search."""
|
||||
knowledge_content = """
|
||||
Project Alpha is a web application built with React and Node.js.
|
||||
Project Beta is a mobile application built with React Native.
|
||||
The team uses Agile methodology with 2-week sprints.
|
||||
The database is PostgreSQL with Redis for caching.
|
||||
"""
|
||||
|
||||
string_source = StringKnowledgeSource(content=knowledge_content)
|
||||
|
||||
project_manager = Agent(
|
||||
role="Project Manager",
|
||||
goal="Gather project information",
|
||||
backstory="You manage software projects.",
|
||||
llm=LLM(model="gpt-4o-mini"),
|
||||
knowledge_sources=[string_source],
|
||||
)
|
||||
|
||||
tech_lead = Agent(
|
||||
role="Technical Lead",
|
||||
goal="Provide technical details",
|
||||
backstory="You are a technical expert.",
|
||||
llm=LLM(model="gpt-4o-mini"),
|
||||
knowledge_sources=[string_source],
|
||||
)
|
||||
|
||||
project_overview_task = Task(
|
||||
description="Provide an overview of Project Alpha",
|
||||
expected_output="Overview of Project Alpha including its technology stack",
|
||||
agent=project_manager,
|
||||
)
|
||||
|
||||
technical_details_task = Task(
|
||||
description="Provide technical implementation details for the project",
|
||||
expected_output="Technical implementation details including database and caching",
|
||||
agent=tech_lead,
|
||||
context=[project_overview_task],
|
||||
)
|
||||
|
||||
crew = Crew(agents=[project_manager, tech_lead], tasks=[project_overview_task, technical_details_task])
|
||||
|
||||
result = crew.kickoff()
|
||||
|
||||
assert result.raw is not None
|
||||
assert any(keyword in result.raw.lower() for keyword in ["react", "node", "postgresql", "redis"])
|
||||
|
||||
|
||||
def test_knowledge_search_query_template_with_context():
|
||||
"""Test that the knowledge search query template includes context properly."""
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test knowledge search",
|
||||
backstory="Test agent",
|
||||
llm=LLM(model="gpt-4o-mini"),
|
||||
)
|
||||
|
||||
task_prompt = "What is the main product?"
|
||||
context = "Previous research shows the company focuses on CRM solutions."
|
||||
|
||||
with patch.object(agent.llm, 'call') as mock_call:
|
||||
mock_call.return_value = "mocked response"
|
||||
|
||||
agent._get_knowledge_search_query(task_prompt, context)
|
||||
|
||||
mock_call.assert_called_once()
|
||||
call_args = mock_call.call_args[0][0]
|
||||
user_message = call_args[1]['content']
|
||||
|
||||
assert task_prompt in user_message
|
||||
assert context in user_message
|
||||
assert "Context from previous tasks:" in user_message
|
||||
|
||||
|
||||
def test_knowledge_search_query_template_without_context():
|
||||
"""Test that the knowledge search query template works without context."""
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test knowledge search",
|
||||
backstory="Test agent",
|
||||
llm=LLM(model="gpt-4o-mini"),
|
||||
)
|
||||
|
||||
task_prompt = "What is the main product?"
|
||||
|
||||
with patch.object(agent.llm, 'call') as mock_call:
|
||||
mock_call.return_value = "mocked response"
|
||||
|
||||
agent._get_knowledge_search_query(task_prompt)
|
||||
|
||||
mock_call.assert_called_once()
|
||||
call_args = mock_call.call_args[0][0]
|
||||
user_message = call_args[1]['content']
|
||||
|
||||
assert task_prompt in user_message
|
||||
assert "Context from previous tasks:" not in user_message
|
||||
|
||||
|
||||
def test_structured_context_integration():
|
||||
"""Test context-aware knowledge search with structured context data."""
|
||||
knowledge_content = """
|
||||
Error URS-01: User registration service unavailable.
|
||||
Method getUserStatus returns user account status.
|
||||
API endpoint /api/users/{id}/status for user status queries.
|
||||
Database table user_accounts stores user information.
|
||||
"""
|
||||
|
||||
string_source = StringKnowledgeSource(content=knowledge_content)
|
||||
|
||||
agent = Agent(
|
||||
role="Technical Support",
|
||||
goal="Resolve technical issues",
|
||||
backstory="You help resolve technical problems.",
|
||||
llm=LLM(model="gpt-4o-mini"),
|
||||
knowledge_sources=[string_source],
|
||||
)
|
||||
|
||||
task_prompt = "How to resolve the user status error?"
|
||||
structured_context = '{"method": "getUserStatus", "error_code": "URS-01", "endpoint": "/api/users/{id}/status"}'
|
||||
|
||||
with patch.object(agent.llm, 'call') as mock_call:
|
||||
mock_call.return_value = "Check getUserStatus method and URS-01 error"
|
||||
|
||||
agent._get_knowledge_search_query(task_prompt, structured_context)
|
||||
|
||||
mock_call.assert_called_once()
|
||||
call_args = mock_call.call_args[0][0]
|
||||
user_message = call_args[1]['content']
|
||||
|
||||
assert task_prompt in user_message
|
||||
assert "getUserStatus" in user_message
|
||||
assert "URS-01" in user_message
|
||||
assert "Context from previous tasks:" in user_message
|
||||
@@ -2,8 +2,8 @@ import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
# Remove the module-level patch
|
||||
from crewai import Agent, Task, Crew
|
||||
from crewai.flow.flow import Flow, start
|
||||
from crewai.utilities.events.listeners.tracing.trace_listener import (
|
||||
TraceCollectionListener,
|
||||
)
|
||||
@@ -284,29 +284,42 @@ class TestTraceListenerSetup:
|
||||
f"Found {len(trace_handlers)} trace handlers when tracing should be disabled"
|
||||
)
|
||||
|
||||
def test_trace_listener_setup_correctly(self):
|
||||
def test_trace_listener_setup_correctly_for_crew(self):
|
||||
"""Test that trace listener is set up correctly when enabled"""
|
||||
|
||||
with patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "true"}):
|
||||
trace_listener = TraceCollectionListener()
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
task = Task(
|
||||
description="Say hello to the world",
|
||||
expected_output="hello world",
|
||||
agent=agent,
|
||||
)
|
||||
with patch.object(
|
||||
TraceCollectionListener, "setup_listeners"
|
||||
) as mock_listener_setup:
|
||||
Crew(agents=[agent], tasks=[task], verbose=True)
|
||||
assert mock_listener_setup.call_count >= 1
|
||||
|
||||
assert trace_listener.trace_enabled is True
|
||||
assert trace_listener.batch_manager is not None
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_trace_listener_setup_correctly_with_tracing_flag(self):
|
||||
def test_trace_listener_setup_correctly_for_flow(self):
|
||||
"""Test that trace listener is set up correctly when enabled"""
|
||||
agent = Agent(role="Test Agent", goal="Test goal", backstory="Test backstory")
|
||||
task = Task(
|
||||
description="Say hello to the world",
|
||||
expected_output="hello world",
|
||||
agent=agent,
|
||||
)
|
||||
crew = Crew(agents=[agent], tasks=[task], verbose=True, tracing=True)
|
||||
crew.kickoff()
|
||||
trace_listener = TraceCollectionListener(tracing=True)
|
||||
assert trace_listener.trace_enabled is True
|
||||
assert trace_listener.batch_manager is not None
|
||||
|
||||
with patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "true"}):
|
||||
|
||||
class FlowExample(Flow):
|
||||
@start()
|
||||
def start(self):
|
||||
pass
|
||||
|
||||
with patch.object(
|
||||
TraceCollectionListener, "setup_listeners"
|
||||
) as mock_listener_setup:
|
||||
FlowExample()
|
||||
assert mock_listener_setup.call_count >= 1
|
||||
|
||||
# Helper method to ensure cleanup
|
||||
def teardown_method(self):
|
||||
|
||||
Reference in New Issue
Block a user