Compare commits

...

2 Commits

Author SHA1 Message Date
Brandon Hancock
d0d42a76cb fix linting 2024-10-04 16:14:11 -04:00
Brandon Hancock
5fe43b9ad2 reduce import by 6x 2024-10-04 16:05:21 -04:00
7 changed files with 43 additions and 29 deletions

View File

@@ -1,18 +1,19 @@
import os
from inspect import signature
from typing import Any, List, Optional, Union
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
from crewai.agents import CacheHandler
from crewai.utilities import Converter, Prompts
from crewai.tools.agent_tools import AgentTools
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.llm import LLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.tools.agent_tools import AgentTools
from crewai.utilities import Converter, Prompts
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
def mock_agent_ops_provider():
@@ -292,9 +293,9 @@ class Agent(BaseAgent):
step_callback=self.step_callback,
function_calling_llm=self.function_calling_llm,
respect_context_window=self.respect_context_window,
request_within_rpm_limit=self._rpm_controller.check_or_wait
if self._rpm_controller
else None,
request_within_rpm_limit=(
self._rpm_controller.check_or_wait if self._rpm_controller else None
),
callbacks=[TokenCalcHandler(self._token_process)],
)

View File

@@ -5,11 +5,6 @@ import os
import shutil
from typing import Any, Dict, List, Optional
from embedchain import App
from embedchain.llm.base import BaseLlm
from embedchain.models.data_type import DataType
from embedchain.vectordb.chroma import InvalidDimensionException
from crewai.memory.storage.interface import Storage
from crewai.utilities.paths import db_storage_path
@@ -29,10 +24,6 @@ def suppress_logging(
logger.setLevel(original_level)
class FakeLLM(BaseLlm):
pass
class RAGStorage(Storage):
"""
Extends Storage to handle embeddings for memory entries, improving
@@ -74,9 +65,19 @@ class RAGStorage(Storage):
if embedder_config:
config["embedder"] = embedder_config
self.type = type
self.app = App.from_config(config=config)
self.config = config
self.allow_reset = allow_reset
def _initialize_app(self):
from embedchain import App
from embedchain.llm.base import BaseLlm
class FakeLLM(BaseLlm):
pass
self.app = App.from_config(config=self.config)
self.app.llm = FakeLLM()
if allow_reset:
if self.allow_reset:
self.app.reset()
def _sanitize_role(self, role: str) -> str:
@@ -86,6 +87,8 @@ class RAGStorage(Storage):
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
if not hasattr(self, "app"):
self._initialize_app()
self._generate_embedding(value, metadata)
def search( # type: ignore # BUG?: Signature of "search" incompatible with supertype "Storage"
@@ -95,6 +98,10 @@ class RAGStorage(Storage):
filter: Optional[dict] = None,
score_threshold: float = 0.35,
) -> List[Any]:
if not hasattr(self, "app"):
self._initialize_app()
from embedchain.vectordb.chroma import InvalidDimensionException
with suppress_logging():
try:
results = (
@@ -108,6 +115,10 @@ class RAGStorage(Storage):
return [r for r in results if r["metadata"]["score"] >= score_threshold]
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> Any:
if not hasattr(self, "app"):
self._initialize_app()
from embedchain.models.data_type import DataType
self.app.add(text, data_type=DataType.TEXT, metadata=metadata)
def reset(self) -> None:

View File

@@ -21,9 +21,7 @@ with suppress_warnings():
from opentelemetry import trace # noqa: E402
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter, # noqa: E402
)
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter # noqa: E402
from opentelemetry.sdk.resources import SERVICE_NAME, Resource # noqa: E402
from opentelemetry.sdk.trace import TracerProvider # noqa: E402
from opentelemetry.sdk.trace.export import BatchSpanProcessor # noqa: E402

View File

@@ -1,4 +1,3 @@
from langchain.tools import StructuredTool
from crewai.agents.agent_builder.utilities.base_agent_tool import BaseAgentTools
@@ -6,6 +5,8 @@ class AgentTools(BaseAgentTools):
"""Default tools around agent delegation"""
def tools(self):
from langchain.tools import StructuredTool
coworkers = ", ".join([f"{agent.role}" for agent in self.agents])
tools = [
StructuredTool.from_function(

View File

@@ -1,4 +1,3 @@
from langchain.tools import StructuredTool
from pydantic import BaseModel, Field
from crewai.agents.cache import CacheHandler
@@ -14,6 +13,8 @@ class CacheTools(BaseModel):
)
def tool(self):
from langchain.tools import StructuredTool
return StructuredTool.from_function(
func=self.hit_cache,
name=self.name,

View File

@@ -1,8 +1,5 @@
from typing import Any, Optional, Type
import instructor
from litellm import completion
class InternalInstructor:
"""Class that wraps an agent llm with instructor."""
@@ -28,6 +25,10 @@ class InternalInstructor:
if self.agent and not self.llm:
self.llm = self.agent.function_calling_llm or self.agent.llm
# Lazy import
import instructor
from litellm import completion
self._client = instructor.from_litellm(
completion,
mode=instructor.Mode.TOOLS,

View File

@@ -1,4 +1,5 @@
from litellm.integrations.custom_logger import CustomLogger
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess