mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 16:48:30 +00:00
fix: Update embedding configuration and fix type errors
- Add configurable embedding providers (OpenAI, Ollama)
- Fix type hints in base_tool and structured_tool
- Add proper json property implementations
- Update documentation for memory configuration
- Add environment variables for embedding configuration
- Fix type errors in task and crew output classes

Co-Authored-By: Joe Moura <joao@crewai.com>
docs/memory.md (new file, 45 lines)
@@ -0,0 +1,45 @@
# Memory in CrewAI

CrewAI provides a robust memory system that allows agents to retain and recall information from previous interactions.

## Configuring Embedding Providers

CrewAI supports multiple embedding providers for memory functionality:

- OpenAI (default) - Requires `OPENAI_API_KEY`
- Ollama - Requires `CREWAI_OLLAMA_URL` (defaults to "http://localhost:11434/api/embeddings")

### Environment Variables

Configure the embedding provider using these environment variables:

- `CREWAI_EMBEDDING_PROVIDER`: Provider name (default: "openai")
- `CREWAI_EMBEDDING_MODEL`: Model name (default: "text-embedding-3-small")
- `CREWAI_OLLAMA_URL`: URL for the Ollama API (when using the Ollama provider)

### Example Configuration

```python
# Using OpenAI (default)
os.environ["OPENAI_API_KEY"] = "your-api-key"

# Using Ollama
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
os.environ["CREWAI_EMBEDDING_MODEL"] = "llama2"  # or any other model supported by your Ollama instance
os.environ["CREWAI_OLLAMA_URL"] = "http://localhost:11434/api/embeddings"  # optional, this is the default
```

## Memory Usage

When an agent has memory enabled, it can access and store information from previous interactions:

```python
agent = Agent(
    role="Researcher",
    goal="Research AI topics",
    backstory="You're an AI researcher",
    memory=True  # Enable memory for this agent
)
```

The memory system uses embeddings to store and retrieve relevant information, allowing agents to maintain context across multiple interactions and tasks.
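Putting the doc's two snippets together, a minimal end-to-end sketch (values are illustrative; assumes `crewai` is installed, a running Ollama instance, and an LLM configured — recent crewai versions also require `expected_output` on `Task`):

```python
import os

# Point CrewAI's memory embeddings at a local Ollama instance.
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
os.environ["CREWAI_EMBEDDING_MODEL"] = "llama2"

from crewai import Agent, Task, Crew

researcher = Agent(
    role="Researcher",
    goal="Research AI topics",
    backstory="You're an AI researcher",
    memory=True,  # retained context is embedded and recalled across tasks
)

task = Task(
    description="Summarize recent embedding-model trade-offs.",
    expected_output="A short summary.",
    agent=researcher,
)

crew = Crew(agents=[researcher], tasks=[task])
# result = crew.kickoff()  # requires live LLM and Ollama endpoints
```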
@@ -243,6 +243,15 @@ class Agent(BaseAgent):
         if isinstance(self.knowledge_sources, list) and all(
             isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
         ):
+            # Validate embedding configuration based on provider
+            from crewai.utilities.constants import DEFAULT_EMBEDDING_PROVIDER
+            provider = os.getenv("CREWAI_EMBEDDING_PROVIDER", DEFAULT_EMBEDDING_PROVIDER)
+
+            if provider == "openai" and not os.getenv("OPENAI_API_KEY"):
+                raise ValueError("Please provide an OpenAI API key via OPENAI_API_KEY environment variable")
+            elif provider == "ollama" and not os.getenv("CREWAI_OLLAMA_URL", "http://localhost:11434/api/embeddings"):
+                raise ValueError("Please provide Ollama URL via CREWAI_OLLAMA_URL environment variable")
+
             self._knowledge = Knowledge(
                 sources=self.knowledge_sources,
                 embedder_config=self.embedder_config,
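The fail-fast check added above, as a standalone sketch (`validate_embedding_env` is a name made up for illustration). Note that the `ollama` branch is effectively unreachable as written, because `os.getenv` is called with a default URL and therefore always returns a truthy value:

```python
import os

DEFAULT_EMBEDDING_PROVIDER = "openai"  # mirrors crewai.utilities.constants

def validate_embedding_env() -> None:
    """Fail fast when the selected embedding provider lacks its credentials."""
    provider = os.getenv("CREWAI_EMBEDDING_PROVIDER", DEFAULT_EMBEDDING_PROVIDER)
    if provider == "openai" and not os.getenv("OPENAI_API_KEY"):
        raise ValueError("Please provide an OpenAI API key via OPENAI_API_KEY environment variable")
    elif provider == "ollama" and not os.getenv("CREWAI_OLLAMA_URL", "http://localhost:11434/api/embeddings"):
        # Never raises: the default above makes the getenv result always truthy.
        raise ValueError("Please provide Ollama URL via CREWAI_OLLAMA_URL environment variable")
```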
@@ -4,7 +4,7 @@ import uuid
 import warnings
 from concurrent.futures import Future
 from hashlib import md5
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

 from pydantic import (
     UUID4,
@@ -797,7 +797,7 @@ class Crew(BaseModel):
             return skipped_task_output
         return None

-    def _prepare_tools(self, agent: BaseAgent, task: Task, tools: List[Tool]) -> List[Tool]:
+    def _prepare_tools(self, agent: BaseAgent, task: Task, tools: Sequence[Tool]) -> List[Tool]:
         # Add delegation tools if agent allows delegation
         if agent.allow_delegation:
             if self.process == Process.hierarchical:
@@ -823,7 +823,7 @@ class Crew(BaseModel):
             return self.manager_agent
         return task.agent

-    def _merge_tools(self, existing_tools: List[Tool], new_tools: List[Tool]) -> List[Tool]:
+    def _merge_tools(self, existing_tools: Sequence[Tool], new_tools: Sequence[Tool]) -> List[Tool]:
         """Merge new tools into existing tools list, avoiding duplicates by tool name."""
         if not new_tools:
             return existing_tools
@@ -839,19 +839,19 @@ class Crew(BaseModel):

         return tools

-    def _inject_delegation_tools(self, tools: List[Tool], task_agent: BaseAgent, agents: List[BaseAgent]):
+    def _inject_delegation_tools(self, tools: Sequence[Tool], task_agent: BaseAgent, agents: Sequence[BaseAgent]):
         delegation_tools = task_agent.get_delegation_tools(agents)
         return self._merge_tools(tools, delegation_tools)

-    def _add_multimodal_tools(self, agent: BaseAgent, tools: List[Tool]):
+    def _add_multimodal_tools(self, agent: BaseAgent, tools: Sequence[Tool]):
         multimodal_tools = agent.get_multimodal_tools()
         return self._merge_tools(tools, multimodal_tools)

-    def _add_code_execution_tools(self, agent: BaseAgent, tools: List[Tool]):
+    def _add_code_execution_tools(self, agent: BaseAgent, tools: Sequence[Tool]):
         code_tools = agent.get_code_execution_tools()
         return self._merge_tools(tools, code_tools)

-    def _add_delegation_tools(self, task: Task, tools: List[Tool]):
+    def _add_delegation_tools(self, task: Task, tools: Sequence[Tool]):
         agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
         if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
             if not tools:
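Why the parameters become `Sequence[Tool]` while the return types stay `List[Tool]`: a `Sequence` is read-only, so callers may pass any sequence type and implementations must copy before mutating. A self-contained sketch of the pattern (plain strings stand in for `Tool` instances):

```python
from typing import List, Sequence

def merge_tools(existing_tools: Sequence[str], new_tools: Sequence[str]) -> List[str]:
    """Merge new tools into the existing ones, skipping duplicates."""
    tools = list(existing_tools)  # copy first: a Sequence has no .append()
    seen = set(tools)
    for tool in new_tools:
        if tool not in seen:
            tools.append(tool)
            seen.add(tool)
    return tools

print(merge_tools(("search", "scrape"), ["scrape", "code"]))  # ['search', 'scrape', 'code']
```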
@@ -1,12 +1,16 @@
 import json
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Set, Union

 from pydantic import BaseModel, Field
+from typing_extensions import Literal

 from crewai.tasks.output_format import OutputFormat
 from crewai.tasks.task_output import TaskOutput
 from crewai.types.usage_metrics import UsageMetrics

+# Type definition for include/exclude parameters
+IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
+

 class CrewOutput(BaseModel):
     """Class that represents the result of a crew."""
@@ -24,13 +28,41 @@ class CrewOutput(BaseModel):
     token_usage: UsageMetrics = Field(description="Processed token summary", default={})

     @property
-    def json(self) -> Optional[str]:
-        if self.tasks_output[-1].output_format != OutputFormat.JSON:
+    def json(self) -> str:
+        """Get the JSON representation of the output."""
+        if self.tasks_output and self.tasks_output[-1].output_format != OutputFormat.JSON:
             raise ValueError(
                 "No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
             )
-
-        return json.dumps(self.json_dict)
+        return json.dumps(self.json_dict) if self.json_dict else "{}"
+
+    def model_dump_json(
+        self,
+        *,
+        indent: Optional[int] = None,
+        include: Optional[IncEx] = None,
+        exclude: Optional[IncEx] = None,
+        context: Optional[Any] = None,
+        by_alias: bool = False,
+        exclude_unset: bool = False,
+        exclude_defaults: bool = False,
+        exclude_none: bool = False,
+        round_trip: bool = False,
+        warnings: bool | Literal["none", "warn", "error"] = False,
+        serialize_as_any: bool = False,
+    ) -> str:
+        """Override model_dump_json to handle custom JSON output."""
+        return super().model_dump_json(
+            indent=indent,
+            include=include,
+            exclude=exclude,
+            by_alias=by_alias,
+            exclude_unset=exclude_unset,
+            exclude_defaults=exclude_defaults,
+            exclude_none=exclude_none,
+            round_trip=round_trip,
+            warnings=warnings,
+        )

     def to_dict(self) -> Dict[str, Any]:
         """Convert json_output and pydantic_output to a dictionary."""
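The reworked `json` property now always returns a `str`, falling back to `"{}"` instead of returning `None` when no dict is set. A toy model of that contract (class and field names invented for illustration):

```python
import json
from typing import Optional

class MiniOutput:
    """Toy stand-in for CrewOutput's json property contract."""
    def __init__(self, json_dict: Optional[dict]):
        self.json_dict = json_dict

    @property
    def json(self) -> str:
        # Always a str: the serialized dict when present, "{}" otherwise.
        return json.dumps(self.json_dict) if self.json_dict else "{}"

print(MiniOutput({"answer": 42}).json)  # {"answer": 42}
print(MiniOutput(None).json)            # {}
```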
@@ -47,7 +47,7 @@ class FastEmbed(BaseEmbedder):
             cache_dir=str(cache_dir) if cache_dir else None,
         )

-    def embed_chunks(self, chunks: List[str]) -> List[np.ndarray]:
+    def embed_chunks(self, chunks: List[str]) -> np.ndarray:
         """
         Generate embeddings for a list of text chunks

@@ -55,12 +55,12 @@ class FastEmbed(BaseEmbedder):
             chunks: List of text chunks to embed

         Returns:
-            List of embeddings
+            Array of embeddings
         """
         embeddings = list(self.model.embed(chunks))
-        return embeddings
+        return np.stack(embeddings)

-    def embed_texts(self, texts: List[str]) -> List[np.ndarray]:
+    def embed_texts(self, texts: List[str]) -> np.ndarray:
         """
         Generate embeddings for a list of texts

@@ -68,10 +68,10 @@ class FastEmbed(BaseEmbedder):
             texts: List of texts to embed

         Returns:
-            List of embeddings
+            Array of embeddings
         """
         embeddings = list(self.model.embed(texts))
-        return embeddings
+        return np.stack(embeddings)

     def embed_text(self, text: str) -> np.ndarray:
         """
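The `np.stack` change turns a Python list of per-chunk vectors into a single 2-D array of shape `(n_chunks, dim)`, which downstream numpy code can slice and batch directly. A quick sketch (384 is just an example embedding dimension):

```python
import numpy as np

# Stand-in for the vectors returned by self.model.embed(...).
embeddings = [np.random.rand(384) for _ in range(3)]

stacked = np.stack(embeddings)
print(type(embeddings), len(embeddings))  # <class 'list'> 3
print(stacked.shape)                      # (3, 384)
```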
@@ -154,8 +154,12 @@ class KnowledgeStorage(BaseKnowledgeStorage):
                 filtered_ids.append(doc_id)

             # If we have no metadata at all, set it to None
-            final_metadata: Optional[OneOrMany[chromadb.Metadata]] = (
-                None if all(m is None for m in filtered_metadata) else filtered_metadata
+            final_metadata: Optional[List[Dict[str, Union[str, int, float, bool]]]] = (
+                None if all(m is None for m in filtered_metadata) else [
+                    {k: v for k, v in m.items() if isinstance(v, (str, int, float, bool))}
+                    if m is not None else None
+                    for m in filtered_metadata
+                ]
             )

             self.collection.upsert(
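The new comprehension narrows each metadata dict to the primitive value types Chroma accepts, silently dropping anything else. A standalone sketch (the sample metadata is invented):

```python
from typing import Dict, List, Optional, Union

Primitive = Union[str, int, float, bool]

filtered_metadata: List[Optional[dict]] = [
    {"source": "docs", "tags": ["a", "b"], "score": 0.9},  # "tags" is not a primitive value
    None,
]

final_metadata: Optional[List[Optional[Dict[str, Primitive]]]] = (
    None if all(m is None for m in filtered_metadata) else [
        {k: v for k, v in m.items() if isinstance(v, (str, int, float, bool))}
        if m is not None else None
        for m in filtered_metadata
    ]
)
print(final_metadata)  # [{'source': 'docs', 'score': 0.9}, None]
```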
@@ -15,6 +15,7 @@ from typing import (
     Dict,
     List,
     Optional,
+    Sequence,
     Set,
     Tuple,
     Type,
@@ -250,7 +251,7 @@ class Task(BaseModel):
         self,
         agent: Optional[BaseAgent] = None,
         context: Optional[str] = None,
-        tools: Optional[List[BaseTool]] = None,
+        tools: Optional[Sequence[BaseTool]] = None,
     ) -> TaskOutput:
         """Execute the task synchronously."""
         return self._execute_core(agent, context, tools)
@@ -267,7 +268,7 @@ class Task(BaseModel):
         self,
         agent: BaseAgent | None = None,
         context: Optional[str] = None,
-        tools: Optional[List[BaseTool]] = None,
+        tools: Optional[Sequence[BaseTool]] = None,
     ) -> Future[TaskOutput]:
         """Execute the task asynchronously."""
         future: Future[TaskOutput] = Future()
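Accepting `Sequence[BaseTool]` also buys covariance: `List` is invariant, so a `List[SearchTool]` is not a `List[BaseTool]` to a type checker, but it is a `Sequence[BaseTool]`. A sketch of the distinction (the tool classes are invented):

```python
from typing import List, Sequence

class BaseTool: ...
class SearchTool(BaseTool): ...

def run_with_list(tools: List[BaseTool]) -> None: ...
def run_with_sequence(tools: Sequence[BaseTool]) -> None: ...

search_tools: List[SearchTool] = [SearchTool()]

run_with_list(search_tools)      # rejected by mypy: List is invariant
run_with_sequence(search_tools)  # accepted: Sequence is covariant
```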
@@ -1,10 +1,14 @@
 import json
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Set, Union

 from pydantic import BaseModel, Field, model_validator
+from typing_extensions import Literal

 from crewai.tasks.output_format import OutputFormat

+# Type definition for include/exclude parameters
+IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
+

 class TaskOutput(BaseModel):
     """Class that represents the result of a task."""
@@ -35,7 +39,8 @@ class TaskOutput(BaseModel):
         return self

     @property
-    def json(self) -> Optional[str]:
+    def json(self) -> str:
+        """Get the JSON representation of the output."""
         if self.output_format != OutputFormat.JSON:
             raise ValueError(
                 """
@@ -44,8 +49,35 @@ class TaskOutput(BaseModel):
                 please make sure to set the output_json property for the task
                 """
             )
-
-        return json.dumps(self.json_dict)
+        return json.dumps(self.json_dict) if self.json_dict else "{}"

+    def model_dump_json(
+        self,
+        *,
+        indent: Optional[int] = None,
+        include: Optional[IncEx] = None,
+        exclude: Optional[IncEx] = None,
+        context: Optional[Any] = None,
+        by_alias: bool = False,
+        exclude_unset: bool = False,
+        exclude_defaults: bool = False,
+        exclude_none: bool = False,
+        round_trip: bool = False,
+        warnings: bool | Literal["none", "warn", "error"] = False,
+        serialize_as_any: bool = False,
+    ) -> str:
+        """Override model_dump_json to handle custom JSON output."""
+        return super().model_dump_json(
+            indent=indent,
+            include=include,
+            exclude=exclude,
+            by_alias=by_alias,
+            exclude_unset=exclude_unset,
+            exclude_defaults=exclude_defaults,
+            exclude_none=exclude_none,
+            round_trip=round_trip,
+            warnings=warnings,
+        )

     def to_dict(self) -> Dict[str, Any]:
         """Convert json_output and pydantic_output to a dictionary."""
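The `IncEx` alias mirrors the shape pydantic v2 expects for its `include`/`exclude` parameters: sets of field names or indices, or nested dicts. A minimal demonstration on a plain model (field names invented):

```python
from typing import Any, Dict, Set, Union
from pydantic import BaseModel

IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]

class Output(BaseModel):
    raw: str = "result text"
    summary: str = "short"

exclude: IncEx = {"raw"}
print(Output().model_dump_json(exclude=exclude))  # {"summary":"short"}
```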
@@ -1,10 +1,15 @@
 from abc import ABC, abstractmethod
 from inspect import signature
-from typing import Any, Callable, Type, get_args, get_origin
+from typing import Any, Callable, Dict, Optional, Type, Tuple, get_args, get_origin

 from pydantic import BaseModel, ConfigDict, Field, create_model, validator
+from pydantic.fields import FieldInfo
 from pydantic import BaseModel as PydanticBaseModel

+def _create_model_fields(fields: Dict[str, Tuple[Any, FieldInfo]]) -> Dict[str, Any]:
+    """Helper function to create model fields with proper type hints."""
+    return {name: (annotation, field) for name, (annotation, field) in fields.items()}
+
 from crewai.tools.structured_tool import CrewStructuredTool


@@ -12,7 +17,8 @@ class BaseTool(BaseModel, ABC):
     class _ArgsSchemaPlaceholder(PydanticBaseModel):
         pass

-    model_config = ConfigDict()
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    func: Optional[Callable] = None

     name: str
     """The unique name of the tool that clearly communicates its purpose."""
@@ -104,20 +110,22 @@ class BaseTool(BaseModel, ABC):
                 description="",
             )
             args_fields[name] = (param_annotation, field_info)
+        schema_name = f"{tool.name}Input"
         if args_fields:
-            args_schema = create_model(f"{tool.name}Input", **args_fields)
+            model_fields = _create_model_fields(args_fields)
+            args_schema = create_model(schema_name, __base__=PydanticBaseModel, **model_fields)
         else:
             # Create a default schema with no fields if no parameters are found
-            args_schema = create_model(
-                f"{tool.name}Input", __base__=PydanticBaseModel
-            )
+            args_schema = create_model(schema_name, __base__=PydanticBaseModel)

-        return cls(
+        tool_instance = cls(
             name=getattr(tool, "name", "Unnamed Tool"),
             description=getattr(tool, "description", ""),
-            func=tool.func,
             args_schema=args_schema,
         )
+        if hasattr(tool, "func"):
+            tool_instance.func = tool.func
+        return tool_instance

     def _set_args_schema(self):
         if self.args_schema is None:
@@ -171,6 +179,12 @@ class Tool(BaseTool):
     """The function that will be executed when the tool is called."""

     func: Callable
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    def __init__(self, **kwargs):
+        if "func" not in kwargs:
+            raise ValueError("Tool requires a 'func' argument")
+        super().__init__(**kwargs)

     def _run(self, *args: Any, **kwargs: Any) -> Any:
         return self.func(*args, **kwargs)
@@ -212,20 +226,22 @@ class Tool(BaseTool):
             description="",
         )
         args_fields[name] = (param_annotation, field_info)
+        schema_name = f"{tool.name}Input"
         if args_fields:
-            args_schema = create_model(f"{tool.name}Input", **args_fields)
+            model_fields = _create_model_fields(args_fields)
+            args_schema = create_model(schema_name, __base__=PydanticBaseModel, **model_fields)
         else:
             # Create a default schema with no fields if no parameters are found
-            args_schema = create_model(
-                f"{tool.name}Input", __base__=PydanticBaseModel
-            )
+            args_schema = create_model(schema_name, __base__=PydanticBaseModel)

-        return cls(
+        tool_instance = cls(
             name=getattr(tool, "name", "Unnamed Tool"),
             description=getattr(tool, "description", ""),
-            func=tool.func,
             args_schema=args_schema,
         )
+        if hasattr(tool, "func"):
+            tool_instance.func = tool.func
+        return tool_instance


     def to_langchain(
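`create_model` builds the args schema dynamically from `(annotation, FieldInfo)` pairs, which is exactly the shape `_create_model_fields` passes through. A standalone sketch of that construction (tool and field names invented):

```python
from pydantic import BaseModel, Field, create_model

# {field name: (annotation, FieldInfo)} — the same shape _create_model_fields returns.
args_fields = {
    "query": (str, Field(description="Search query")),
    "limit": (int, Field(default=10, description="Max results")),
}

SearchInput = create_model("SearchToolInput", __base__=BaseModel, **args_fields)
print(SearchInput(query="crewai").model_dump())  # {'query': 'crewai', 'limit': 10}
```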
@@ -2,9 +2,14 @@ from __future__ import annotations

 import inspect
 import textwrap
-from typing import Any, Callable, Optional, Union, get_type_hints
+from typing import Any, Callable, Dict, Optional, Tuple, Union, get_type_hints

-from pydantic import BaseModel, Field, create_model
+from pydantic import BaseModel, ConfigDict, Field, create_model
+from pydantic.fields import FieldInfo
+
+def _create_model_fields(fields: Dict[str, Tuple[Any, FieldInfo]]) -> Dict[str, Any]:
+    """Helper function to create model fields with proper type hints."""
+    return {name: (annotation, field) for name, (annotation, field) in fields.items()}

 from crewai.utilities.logger import Logger

@@ -142,7 +147,8 @@ class CrewStructuredTool:

         # Create model
         schema_name = f"{name.title()}Schema"
-        return create_model(schema_name, **fields)
+        model_fields = _create_model_fields(fields)
+        return create_model(schema_name, __base__=BaseModel, **model_fields)

     def _validate_function_signature(self) -> None:
         """Validate that the function signature matches the args schema."""
@@ -4,3 +4,7 @@ DEFAULT_SCORE_THRESHOLD = 0.35
 KNOWLEDGE_DIRECTORY = "knowledge"
 MAX_LLM_RETRY = 3
 MAX_FILE_NAME_LENGTH = 255
+
+# Default embedding configuration
+DEFAULT_EMBEDDING_PROVIDER = "openai"
+DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
@@ -47,13 +47,22 @@ class EmbeddingConfigurator:

     @staticmethod
     def _create_default_embedding_function():
-        from chromadb.utils.embedding_functions.openai_embedding_function import (
-            OpenAIEmbeddingFunction,
-        )
-
-        return OpenAIEmbeddingFunction(
-            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
-        )
+        from crewai.utilities.constants import DEFAULT_EMBEDDING_PROVIDER, DEFAULT_EMBEDDING_MODEL
+        provider = os.getenv("CREWAI_EMBEDDING_PROVIDER", DEFAULT_EMBEDDING_PROVIDER)
+        model = os.getenv("CREWAI_EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL)
+
+        if provider == "ollama":
+            from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
+            return OllamaEmbeddingFunction(
+                url=os.getenv("CREWAI_OLLAMA_URL", "http://localhost:11434/api/embeddings"),
+                model_name=model
+            )
+        else:
+            from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
+            return OpenAIEmbeddingFunction(
+                api_key=os.getenv("OPENAI_API_KEY"),
+                model_name=model
+            )

     @staticmethod
     def _configure_openai(config, model_name):
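The dispatch reduces to environment lookups with the new constants as fallbacks. A dependency-free sketch of just the selection logic (the returned strings are placeholders, not real embedding functions):

```python
import os

DEFAULT_EMBEDDING_PROVIDER = "openai"
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"

def describe_embedding_function() -> str:
    """Describe which embedding function would be built from the environment."""
    provider = os.getenv("CREWAI_EMBEDDING_PROVIDER", DEFAULT_EMBEDDING_PROVIDER)
    model = os.getenv("CREWAI_EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL)
    if provider == "ollama":
        url = os.getenv("CREWAI_OLLAMA_URL", "http://localhost:11434/api/embeddings")
        return f"OllamaEmbeddingFunction(url={url!r}, model_name={model!r})"
    return f"OpenAIEmbeddingFunction(model_name={model!r})"

os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
print(describe_embedding_function())
# OllamaEmbeddingFunction(url='http://localhost:11434/api/embeddings', model_name='text-embedding-3-small')
```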
@@ -1,4 +1,30 @@
 # conftest.py
 import os
+import pytest
 from dotenv import load_dotenv

 load_result = load_dotenv(override=True)
+
+@pytest.fixture(autouse=True)
+def setup_test_env():
+    """Configure test environment to use Ollama as the default embedding provider."""
+    # Store original environment variables
+    original_env = {
+        "CREWAI_EMBEDDING_PROVIDER": os.environ.get("CREWAI_EMBEDDING_PROVIDER"),
+        "CREWAI_EMBEDDING_MODEL": os.environ.get("CREWAI_EMBEDDING_MODEL"),
+        "CREWAI_OLLAMA_URL": os.environ.get("CREWAI_OLLAMA_URL"),
+    }
+
+    # Set test environment
+    os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
+    os.environ["CREWAI_EMBEDDING_MODEL"] = "llama2"
+    os.environ["CREWAI_OLLAMA_URL"] = "http://localhost:11434/api/embeddings"
+
+    yield
+
+    # Restore original environment
+    for key, value in original_env.items():
+        if value is None:
+            os.environ.pop(key, None)
+        else:
+            os.environ[key] = value
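The fixture's save/override/restore dance generalizes to a context manager; pytest's built-in `monkeypatch.setenv` achieves the same per test. A generic sketch:

```python
import os
from contextlib import contextmanager
from typing import Dict, Iterator

@contextmanager
def temp_env(overrides: Dict[str, str]) -> Iterator[None]:
    """Temporarily override environment variables, restoring them on exit."""
    original = {key: os.environ.get(key) for key in overrides}
    os.environ.update(overrides)
    try:
        yield
    finally:
        for key, value in original.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value

with temp_env({"CREWAI_EMBEDDING_PROVIDER": "ollama"}):
    assert os.environ["CREWAI_EMBEDDING_PROVIDER"] == "ollama"
```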