mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
fix: Update embedding configuration and fix type errors
- Add configurable embedding providers (OpenAI, Ollama) - Fix type hints in base_tool and structured_tool - Add proper json property implementations - Update documentation for memory configuration - Add environment variables for embedding configuration - Fix type errors in task and crew output classes Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
45
docs/memory.md
Normal file
45
docs/memory.md
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# Memory in CrewAI
|
||||||
|
|
||||||
|
CrewAI provides a robust memory system that allows agents to retain and recall information from previous interactions.
|
||||||
|
|
||||||
|
## Configuring Embedding Providers
|
||||||
|
|
||||||
|
CrewAI supports multiple embedding providers for memory functionality:
|
||||||
|
|
||||||
|
- OpenAI (default) - Requires `OPENAI_API_KEY`
|
||||||
|
- Ollama - Requires `CREWAI_OLLAMA_URL` (defaults to "http://localhost:11434/api/embeddings")
|
||||||
|
|
||||||
|
### Environment Variables
|
||||||
|
|
||||||
|
Configure the embedding provider using these environment variables:
|
||||||
|
|
||||||
|
- `CREWAI_EMBEDDING_PROVIDER`: Provider name (default: "openai")
|
||||||
|
- `CREWAI_EMBEDDING_MODEL`: Model name (default: "text-embedding-3-small")
|
||||||
|
- `CREWAI_OLLAMA_URL`: URL for Ollama API (when using Ollama provider)
|
||||||
|
|
||||||
|
### Example Configuration
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Using OpenAI (default)
|
||||||
|
os.environ["OPENAI_API_KEY"] = "your-api-key"
|
||||||
|
|
||||||
|
# Using Ollama
|
||||||
|
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
|
||||||
|
os.environ["CREWAI_EMBEDDING_MODEL"] = "llama2" # or any other model supported by your Ollama instance
|
||||||
|
os.environ["CREWAI_OLLAMA_URL"] = "http://localhost:11434/api/embeddings" # optional, this is the default
|
||||||
|
```
|
||||||
|
|
||||||
|
## Memory Usage
|
||||||
|
|
||||||
|
When an agent has memory enabled, it can access and store information from previous interactions:
|
||||||
|
|
||||||
|
```python
|
||||||
|
agent = Agent(
|
||||||
|
role="Researcher",
|
||||||
|
goal="Research AI topics",
|
||||||
|
backstory="You're an AI researcher",
|
||||||
|
memory=True # Enable memory for this agent
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
The memory system uses embeddings to store and retrieve relevant information, allowing agents to maintain context across multiple interactions and tasks.
|
||||||
@@ -243,6 +243,15 @@ class Agent(BaseAgent):
|
|||||||
if isinstance(self.knowledge_sources, list) and all(
|
if isinstance(self.knowledge_sources, list) and all(
|
||||||
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
|
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
|
||||||
):
|
):
|
||||||
|
# Validate embedding configuration based on provider
|
||||||
|
from crewai.utilities.constants import DEFAULT_EMBEDDING_PROVIDER
|
||||||
|
provider = os.getenv("CREWAI_EMBEDDING_PROVIDER", DEFAULT_EMBEDDING_PROVIDER)
|
||||||
|
|
||||||
|
if provider == "openai" and not os.getenv("OPENAI_API_KEY"):
|
||||||
|
raise ValueError("Please provide an OpenAI API key via OPENAI_API_KEY environment variable")
|
||||||
|
elif provider == "ollama" and not os.getenv("CREWAI_OLLAMA_URL", "http://localhost:11434/api/embeddings"):
|
||||||
|
raise ValueError("Please provide Ollama URL via CREWAI_OLLAMA_URL environment variable")
|
||||||
|
|
||||||
self._knowledge = Knowledge(
|
self._knowledge = Knowledge(
|
||||||
sources=self.knowledge_sources,
|
sources=self.knowledge_sources,
|
||||||
embedder_config=self.embedder_config,
|
embedder_config=self.embedder_config,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import uuid
|
|||||||
import warnings
|
import warnings
|
||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
UUID4,
|
UUID4,
|
||||||
@@ -797,7 +797,7 @@ class Crew(BaseModel):
|
|||||||
return skipped_task_output
|
return skipped_task_output
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _prepare_tools(self, agent: BaseAgent, task: Task, tools: List[Tool]) -> List[Tool]:
|
def _prepare_tools(self, agent: BaseAgent, task: Task, tools: Sequence[Tool]) -> List[Tool]:
|
||||||
# Add delegation tools if agent allows delegation
|
# Add delegation tools if agent allows delegation
|
||||||
if agent.allow_delegation:
|
if agent.allow_delegation:
|
||||||
if self.process == Process.hierarchical:
|
if self.process == Process.hierarchical:
|
||||||
@@ -823,7 +823,7 @@ class Crew(BaseModel):
|
|||||||
return self.manager_agent
|
return self.manager_agent
|
||||||
return task.agent
|
return task.agent
|
||||||
|
|
||||||
def _merge_tools(self, existing_tools: List[Tool], new_tools: List[Tool]) -> List[Tool]:
|
def _merge_tools(self, existing_tools: Sequence[Tool], new_tools: Sequence[Tool]) -> List[Tool]:
|
||||||
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
|
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
|
||||||
if not new_tools:
|
if not new_tools:
|
||||||
return existing_tools
|
return existing_tools
|
||||||
@@ -839,19 +839,19 @@ class Crew(BaseModel):
|
|||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
def _inject_delegation_tools(self, tools: List[Tool], task_agent: BaseAgent, agents: List[BaseAgent]):
|
def _inject_delegation_tools(self, tools: Sequence[Tool], task_agent: BaseAgent, agents: Sequence[BaseAgent]):
|
||||||
delegation_tools = task_agent.get_delegation_tools(agents)
|
delegation_tools = task_agent.get_delegation_tools(agents)
|
||||||
return self._merge_tools(tools, delegation_tools)
|
return self._merge_tools(tools, delegation_tools)
|
||||||
|
|
||||||
def _add_multimodal_tools(self, agent: BaseAgent, tools: List[Tool]):
|
def _add_multimodal_tools(self, agent: BaseAgent, tools: Sequence[Tool]):
|
||||||
multimodal_tools = agent.get_multimodal_tools()
|
multimodal_tools = agent.get_multimodal_tools()
|
||||||
return self._merge_tools(tools, multimodal_tools)
|
return self._merge_tools(tools, multimodal_tools)
|
||||||
|
|
||||||
def _add_code_execution_tools(self, agent: BaseAgent, tools: List[Tool]):
|
def _add_code_execution_tools(self, agent: BaseAgent, tools: Sequence[Tool]):
|
||||||
code_tools = agent.get_code_execution_tools()
|
code_tools = agent.get_code_execution_tools()
|
||||||
return self._merge_tools(tools, code_tools)
|
return self._merge_tools(tools, code_tools)
|
||||||
|
|
||||||
def _add_delegation_tools(self, task: Task, tools: List[Tool]):
|
def _add_delegation_tools(self, task: Task, tools: Sequence[Tool]):
|
||||||
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
|
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
|
||||||
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
|
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
|
||||||
if not tools:
|
if not tools:
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional, Set, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from crewai.tasks.output_format import OutputFormat
|
from crewai.tasks.output_format import OutputFormat
|
||||||
from crewai.tasks.task_output import TaskOutput
|
from crewai.tasks.task_output import TaskOutput
|
||||||
from crewai.types.usage_metrics import UsageMetrics
|
from crewai.types.usage_metrics import UsageMetrics
|
||||||
|
|
||||||
|
# Type definition for include/exclude parameters
|
||||||
|
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
class CrewOutput(BaseModel):
|
class CrewOutput(BaseModel):
|
||||||
"""Class that represents the result of a crew."""
|
"""Class that represents the result of a crew."""
|
||||||
@@ -24,13 +28,41 @@ class CrewOutput(BaseModel):
|
|||||||
token_usage: UsageMetrics = Field(description="Processed token summary", default={})
|
token_usage: UsageMetrics = Field(description="Processed token summary", default={})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def json(self) -> Optional[str]:
|
def json(self) -> str:
|
||||||
if self.tasks_output[-1].output_format != OutputFormat.JSON:
|
"""Get the JSON representation of the output."""
|
||||||
|
if self.tasks_output and self.tasks_output[-1].output_format != OutputFormat.JSON:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
|
"No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
|
||||||
)
|
)
|
||||||
|
return json.dumps(self.json_dict) if self.json_dict else "{}"
|
||||||
|
|
||||||
return json.dumps(self.json_dict)
|
def model_dump_json(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
indent: Optional[int] = None,
|
||||||
|
include: Optional[IncEx] = None,
|
||||||
|
exclude: Optional[IncEx] = None,
|
||||||
|
context: Optional[Any] = None,
|
||||||
|
by_alias: bool = False,
|
||||||
|
exclude_unset: bool = False,
|
||||||
|
exclude_defaults: bool = False,
|
||||||
|
exclude_none: bool = False,
|
||||||
|
round_trip: bool = False,
|
||||||
|
warnings: bool | Literal["none", "warn", "error"] = False,
|
||||||
|
serialize_as_any: bool = False,
|
||||||
|
) -> str:
|
||||||
|
"""Override model_dump_json to handle custom JSON output."""
|
||||||
|
return super().model_dump_json(
|
||||||
|
indent=indent,
|
||||||
|
include=include,
|
||||||
|
exclude=exclude,
|
||||||
|
by_alias=by_alias,
|
||||||
|
exclude_unset=exclude_unset,
|
||||||
|
exclude_defaults=exclude_defaults,
|
||||||
|
exclude_none=exclude_none,
|
||||||
|
round_trip=round_trip,
|
||||||
|
warnings=warnings,
|
||||||
|
)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""Convert json_output and pydantic_output to a dictionary."""
|
"""Convert json_output and pydantic_output to a dictionary."""
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class FastEmbed(BaseEmbedder):
|
|||||||
cache_dir=str(cache_dir) if cache_dir else None,
|
cache_dir=str(cache_dir) if cache_dir else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def embed_chunks(self, chunks: List[str]) -> List[np.ndarray]:
|
def embed_chunks(self, chunks: List[str]) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Generate embeddings for a list of text chunks
|
Generate embeddings for a list of text chunks
|
||||||
|
|
||||||
@@ -55,12 +55,12 @@ class FastEmbed(BaseEmbedder):
|
|||||||
chunks: List of text chunks to embed
|
chunks: List of text chunks to embed
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of embeddings
|
Array of embeddings
|
||||||
"""
|
"""
|
||||||
embeddings = list(self.model.embed(chunks))
|
embeddings = list(self.model.embed(chunks))
|
||||||
return embeddings
|
return np.stack(embeddings)
|
||||||
|
|
||||||
def embed_texts(self, texts: List[str]) -> List[np.ndarray]:
|
def embed_texts(self, texts: List[str]) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Generate embeddings for a list of texts
|
Generate embeddings for a list of texts
|
||||||
|
|
||||||
@@ -68,10 +68,10 @@ class FastEmbed(BaseEmbedder):
|
|||||||
texts: List of texts to embed
|
texts: List of texts to embed
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of embeddings
|
Array of embeddings
|
||||||
"""
|
"""
|
||||||
embeddings = list(self.model.embed(texts))
|
embeddings = list(self.model.embed(texts))
|
||||||
return embeddings
|
return np.stack(embeddings)
|
||||||
|
|
||||||
def embed_text(self, text: str) -> np.ndarray:
|
def embed_text(self, text: str) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -154,8 +154,12 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
filtered_ids.append(doc_id)
|
filtered_ids.append(doc_id)
|
||||||
|
|
||||||
# If we have no metadata at all, set it to None
|
# If we have no metadata at all, set it to None
|
||||||
final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
|
final_metadata: Optional[List[Dict[str, Union[str, int, float, bool]]]] = (
|
||||||
None if all(m is None for m in filtered_metadata) else filtered_metadata
|
None if all(m is None for m in filtered_metadata) else [
|
||||||
|
{k: v for k, v in m.items() if isinstance(v, (str, int, float, bool))}
|
||||||
|
if m is not None else None
|
||||||
|
for m in filtered_metadata
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.collection.upsert(
|
self.collection.upsert(
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from typing import (
|
|||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
Sequence,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
@@ -250,7 +251,7 @@ class Task(BaseModel):
|
|||||||
self,
|
self,
|
||||||
agent: Optional[BaseAgent] = None,
|
agent: Optional[BaseAgent] = None,
|
||||||
context: Optional[str] = None,
|
context: Optional[str] = None,
|
||||||
tools: Optional[List[BaseTool]] = None,
|
tools: Optional[Sequence[BaseTool]] = None,
|
||||||
) -> TaskOutput:
|
) -> TaskOutput:
|
||||||
"""Execute the task synchronously."""
|
"""Execute the task synchronously."""
|
||||||
return self._execute_core(agent, context, tools)
|
return self._execute_core(agent, context, tools)
|
||||||
@@ -267,7 +268,7 @@ class Task(BaseModel):
|
|||||||
self,
|
self,
|
||||||
agent: BaseAgent | None = None,
|
agent: BaseAgent | None = None,
|
||||||
context: Optional[str] = None,
|
context: Optional[str] = None,
|
||||||
tools: Optional[List[BaseTool]] = None,
|
tools: Optional[Sequence[BaseTool]] = None,
|
||||||
) -> Future[TaskOutput]:
|
) -> Future[TaskOutput]:
|
||||||
"""Execute the task asynchronously."""
|
"""Execute the task asynchronously."""
|
||||||
future: Future[TaskOutput] = Future()
|
future: Future[TaskOutput] = Future()
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional, Set, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from crewai.tasks.output_format import OutputFormat
|
from crewai.tasks.output_format import OutputFormat
|
||||||
|
|
||||||
|
# Type definition for include/exclude parameters
|
||||||
|
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
class TaskOutput(BaseModel):
|
class TaskOutput(BaseModel):
|
||||||
"""Class that represents the result of a task."""
|
"""Class that represents the result of a task."""
|
||||||
@@ -35,7 +39,8 @@ class TaskOutput(BaseModel):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def json(self) -> Optional[str]:
|
def json(self) -> str:
|
||||||
|
"""Get the JSON representation of the output."""
|
||||||
if self.output_format != OutputFormat.JSON:
|
if self.output_format != OutputFormat.JSON:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"""
|
"""
|
||||||
@@ -44,8 +49,35 @@ class TaskOutput(BaseModel):
|
|||||||
please make sure to set the output_json property for the task
|
please make sure to set the output_json property for the task
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
return json.dumps(self.json_dict) if self.json_dict else "{}"
|
||||||
|
|
||||||
return json.dumps(self.json_dict)
|
def model_dump_json(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
indent: Optional[int] = None,
|
||||||
|
include: Optional[IncEx] = None,
|
||||||
|
exclude: Optional[IncEx] = None,
|
||||||
|
context: Optional[Any] = None,
|
||||||
|
by_alias: bool = False,
|
||||||
|
exclude_unset: bool = False,
|
||||||
|
exclude_defaults: bool = False,
|
||||||
|
exclude_none: bool = False,
|
||||||
|
round_trip: bool = False,
|
||||||
|
warnings: bool | Literal["none", "warn", "error"] = False,
|
||||||
|
serialize_as_any: bool = False,
|
||||||
|
) -> str:
|
||||||
|
"""Override model_dump_json to handle custom JSON output."""
|
||||||
|
return super().model_dump_json(
|
||||||
|
indent=indent,
|
||||||
|
include=include,
|
||||||
|
exclude=exclude,
|
||||||
|
by_alias=by_alias,
|
||||||
|
exclude_unset=exclude_unset,
|
||||||
|
exclude_defaults=exclude_defaults,
|
||||||
|
exclude_none=exclude_none,
|
||||||
|
round_trip=round_trip,
|
||||||
|
warnings=warnings,
|
||||||
|
)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""Convert json_output and pydantic_output to a dictionary."""
|
"""Convert json_output and pydantic_output to a dictionary."""
|
||||||
|
|||||||
@@ -1,10 +1,15 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import Any, Callable, Type, get_args, get_origin
|
from typing import Any, Callable, Dict, Optional, Type, Tuple, get_args, get_origin
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, create_model, validator
|
from pydantic import BaseModel, ConfigDict, Field, create_model, validator
|
||||||
|
from pydantic.fields import FieldInfo
|
||||||
from pydantic import BaseModel as PydanticBaseModel
|
from pydantic import BaseModel as PydanticBaseModel
|
||||||
|
|
||||||
|
def _create_model_fields(fields: Dict[str, Tuple[Any, FieldInfo]]) -> Dict[str, Any]:
|
||||||
|
"""Helper function to create model fields with proper type hints."""
|
||||||
|
return {name: (annotation, field) for name, (annotation, field) in fields.items()}
|
||||||
|
|
||||||
from crewai.tools.structured_tool import CrewStructuredTool
|
from crewai.tools.structured_tool import CrewStructuredTool
|
||||||
|
|
||||||
|
|
||||||
@@ -12,7 +17,8 @@ class BaseTool(BaseModel, ABC):
|
|||||||
class _ArgsSchemaPlaceholder(PydanticBaseModel):
|
class _ArgsSchemaPlaceholder(PydanticBaseModel):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
model_config = ConfigDict()
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
func: Optional[Callable] = None
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
"""The unique name of the tool that clearly communicates its purpose."""
|
"""The unique name of the tool that clearly communicates its purpose."""
|
||||||
@@ -104,20 +110,22 @@ class BaseTool(BaseModel, ABC):
|
|||||||
description="",
|
description="",
|
||||||
)
|
)
|
||||||
args_fields[name] = (param_annotation, field_info)
|
args_fields[name] = (param_annotation, field_info)
|
||||||
|
schema_name = f"{tool.name}Input"
|
||||||
if args_fields:
|
if args_fields:
|
||||||
args_schema = create_model(f"{tool.name}Input", **args_fields)
|
model_fields = _create_model_fields(args_fields)
|
||||||
|
args_schema = create_model(schema_name, __base__=PydanticBaseModel, **model_fields)
|
||||||
else:
|
else:
|
||||||
# Create a default schema with no fields if no parameters are found
|
# Create a default schema with no fields if no parameters are found
|
||||||
args_schema = create_model(
|
args_schema = create_model(schema_name, __base__=PydanticBaseModel)
|
||||||
f"{tool.name}Input", __base__=PydanticBaseModel
|
|
||||||
)
|
|
||||||
|
|
||||||
return cls(
|
tool_instance = cls(
|
||||||
name=getattr(tool, "name", "Unnamed Tool"),
|
name=getattr(tool, "name", "Unnamed Tool"),
|
||||||
description=getattr(tool, "description", ""),
|
description=getattr(tool, "description", ""),
|
||||||
func=tool.func,
|
|
||||||
args_schema=args_schema,
|
args_schema=args_schema,
|
||||||
)
|
)
|
||||||
|
if hasattr(tool, "func"):
|
||||||
|
tool_instance.func = tool.func
|
||||||
|
return tool_instance
|
||||||
|
|
||||||
def _set_args_schema(self):
|
def _set_args_schema(self):
|
||||||
if self.args_schema is None:
|
if self.args_schema is None:
|
||||||
@@ -171,6 +179,12 @@ class Tool(BaseTool):
|
|||||||
"""The function that will be executed when the tool is called."""
|
"""The function that will be executed when the tool is called."""
|
||||||
|
|
||||||
func: Callable
|
func: Callable
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
if "func" not in kwargs:
|
||||||
|
raise ValueError("Tool requires a 'func' argument")
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
return self.func(*args, **kwargs)
|
return self.func(*args, **kwargs)
|
||||||
@@ -212,20 +226,22 @@ class Tool(BaseTool):
|
|||||||
description="",
|
description="",
|
||||||
)
|
)
|
||||||
args_fields[name] = (param_annotation, field_info)
|
args_fields[name] = (param_annotation, field_info)
|
||||||
|
schema_name = f"{tool.name}Input"
|
||||||
if args_fields:
|
if args_fields:
|
||||||
args_schema = create_model(f"{tool.name}Input", **args_fields)
|
model_fields = _create_model_fields(args_fields)
|
||||||
|
args_schema = create_model(schema_name, __base__=PydanticBaseModel, **model_fields)
|
||||||
else:
|
else:
|
||||||
# Create a default schema with no fields if no parameters are found
|
# Create a default schema with no fields if no parameters are found
|
||||||
args_schema = create_model(
|
args_schema = create_model(schema_name, __base__=PydanticBaseModel)
|
||||||
f"{tool.name}Input", __base__=PydanticBaseModel
|
|
||||||
)
|
|
||||||
|
|
||||||
return cls(
|
tool_instance = cls(
|
||||||
name=getattr(tool, "name", "Unnamed Tool"),
|
name=getattr(tool, "name", "Unnamed Tool"),
|
||||||
description=getattr(tool, "description", ""),
|
description=getattr(tool, "description", ""),
|
||||||
func=tool.func,
|
|
||||||
args_schema=args_schema,
|
args_schema=args_schema,
|
||||||
)
|
)
|
||||||
|
if hasattr(tool, "func"):
|
||||||
|
tool_instance.func = tool.func
|
||||||
|
return tool_instance
|
||||||
|
|
||||||
|
|
||||||
def to_langchain(
|
def to_langchain(
|
||||||
|
|||||||
@@ -2,9 +2,14 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import Any, Callable, Optional, Union, get_type_hints
|
from typing import Any, Callable, Dict, Optional, Tuple, Union, get_type_hints
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, create_model
|
from pydantic import BaseModel, ConfigDict, Field, create_model
|
||||||
|
from pydantic.fields import FieldInfo
|
||||||
|
|
||||||
|
def _create_model_fields(fields: Dict[str, Tuple[Any, FieldInfo]]) -> Dict[str, Any]:
|
||||||
|
"""Helper function to create model fields with proper type hints."""
|
||||||
|
return {name: (annotation, field) for name, (annotation, field) in fields.items()}
|
||||||
|
|
||||||
from crewai.utilities.logger import Logger
|
from crewai.utilities.logger import Logger
|
||||||
|
|
||||||
@@ -142,7 +147,8 @@ class CrewStructuredTool:
|
|||||||
|
|
||||||
# Create model
|
# Create model
|
||||||
schema_name = f"{name.title()}Schema"
|
schema_name = f"{name.title()}Schema"
|
||||||
return create_model(schema_name, **fields)
|
model_fields = _create_model_fields(fields)
|
||||||
|
return create_model(schema_name, __base__=BaseModel, **model_fields)
|
||||||
|
|
||||||
def _validate_function_signature(self) -> None:
|
def _validate_function_signature(self) -> None:
|
||||||
"""Validate that the function signature matches the args schema."""
|
"""Validate that the function signature matches the args schema."""
|
||||||
|
|||||||
@@ -4,3 +4,7 @@ DEFAULT_SCORE_THRESHOLD = 0.35
|
|||||||
KNOWLEDGE_DIRECTORY = "knowledge"
|
KNOWLEDGE_DIRECTORY = "knowledge"
|
||||||
MAX_LLM_RETRY = 3
|
MAX_LLM_RETRY = 3
|
||||||
MAX_FILE_NAME_LENGTH = 255
|
MAX_FILE_NAME_LENGTH = 255
|
||||||
|
|
||||||
|
# Default embedding configuration
|
||||||
|
DEFAULT_EMBEDDING_PROVIDER = "openai"
|
||||||
|
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||||
|
|||||||
@@ -47,13 +47,22 @@ class EmbeddingConfigurator:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_default_embedding_function():
|
def _create_default_embedding_function():
|
||||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
from crewai.utilities.constants import DEFAULT_EMBEDDING_PROVIDER, DEFAULT_EMBEDDING_MODEL
|
||||||
OpenAIEmbeddingFunction,
|
provider = os.getenv("CREWAI_EMBEDDING_PROVIDER", DEFAULT_EMBEDDING_PROVIDER)
|
||||||
)
|
model = os.getenv("CREWAI_EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL)
|
||||||
|
|
||||||
return OpenAIEmbeddingFunction(
|
if provider == "ollama":
|
||||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
|
||||||
)
|
return OllamaEmbeddingFunction(
|
||||||
|
url=os.getenv("CREWAI_OLLAMA_URL", "http://localhost:11434/api/embeddings"),
|
||||||
|
model_name=model
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
|
||||||
|
return OpenAIEmbeddingFunction(
|
||||||
|
api_key=os.getenv("OPENAI_API_KEY"),
|
||||||
|
model_name=model
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_openai(config, model_name):
|
def _configure_openai(config, model_name):
|
||||||
|
|||||||
@@ -1,4 +1,30 @@
|
|||||||
# conftest.py
|
# conftest.py
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
load_result = load_dotenv(override=True)
|
load_result = load_dotenv(override=True)
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_test_env():
|
||||||
|
"""Configure test environment to use Ollama as the default embedding provider."""
|
||||||
|
# Store original environment variables
|
||||||
|
original_env = {
|
||||||
|
"CREWAI_EMBEDDING_PROVIDER": os.environ.get("CREWAI_EMBEDDING_PROVIDER"),
|
||||||
|
"CREWAI_EMBEDDING_MODEL": os.environ.get("CREWAI_EMBEDDING_MODEL"),
|
||||||
|
"CREWAI_OLLAMA_URL": os.environ.get("CREWAI_OLLAMA_URL"),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Set test environment
|
||||||
|
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
|
||||||
|
os.environ["CREWAI_EMBEDDING_MODEL"] = "llama2"
|
||||||
|
os.environ["CREWAI_OLLAMA_URL"] = "http://localhost:11434/api/embeddings"
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Restore original environment
|
||||||
|
for key, value in original_env.items():
|
||||||
|
if value is None:
|
||||||
|
os.environ.pop(key, None)
|
||||||
|
else:
|
||||||
|
os.environ[key] = value
|
||||||
|
|||||||
Reference in New Issue
Block a user