reduce import time by 6x (#1396)

* reduce import time by 6x

* fix linting
This commit is contained in:
Brandon Hancock (bhancock_ai)
2024-10-06 16:55:32 -04:00
committed by GitHub
parent 0dfe3bcb0a
commit 5d8f8cbc79
7 changed files with 43 additions and 29 deletions

View File

@@ -1,18 +1,19 @@
import os import os
from inspect import signature from inspect import signature
from typing import Any, List, Optional, Union from typing import Any, List, Optional, Union
from pydantic import Field, InstanceOf, PrivateAttr, model_validator from pydantic import Field, InstanceOf, PrivateAttr, model_validator
from crewai.agents import CacheHandler from crewai.agents import CacheHandler
from crewai.utilities import Converter, Prompts
from crewai.tools.agent_tools import AgentTools
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.memory.contextual.contextual_memory import ContextualMemory from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.llm import LLM from crewai.llm import LLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.tools.agent_tools import AgentTools
from crewai.utilities import Converter, Prompts
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
def mock_agent_ops_provider(): def mock_agent_ops_provider():
@@ -292,9 +293,9 @@ class Agent(BaseAgent):
step_callback=self.step_callback, step_callback=self.step_callback,
function_calling_llm=self.function_calling_llm, function_calling_llm=self.function_calling_llm,
respect_context_window=self.respect_context_window, respect_context_window=self.respect_context_window,
request_within_rpm_limit=self._rpm_controller.check_or_wait request_within_rpm_limit=(
if self._rpm_controller self._rpm_controller.check_or_wait if self._rpm_controller else None
else None, ),
callbacks=[TokenCalcHandler(self._token_process)], callbacks=[TokenCalcHandler(self._token_process)],
) )

View File

@@ -5,11 +5,6 @@ import os
import shutil import shutil
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from embedchain import App
from embedchain.llm.base import BaseLlm
from embedchain.models.data_type import DataType
from embedchain.vectordb.chroma import InvalidDimensionException
from crewai.memory.storage.interface import Storage from crewai.memory.storage.interface import Storage
from crewai.utilities.paths import db_storage_path from crewai.utilities.paths import db_storage_path
@@ -29,10 +24,6 @@ def suppress_logging(
logger.setLevel(original_level) logger.setLevel(original_level)
class FakeLLM(BaseLlm):
pass
class RAGStorage(Storage): class RAGStorage(Storage):
""" """
Extends Storage to handle embeddings for memory entries, improving Extends Storage to handle embeddings for memory entries, improving
@@ -74,9 +65,19 @@ class RAGStorage(Storage):
if embedder_config: if embedder_config:
config["embedder"] = embedder_config config["embedder"] = embedder_config
self.type = type self.type = type
self.app = App.from_config(config=config) self.config = config
self.allow_reset = allow_reset
def _initialize_app(self):
from embedchain import App
from embedchain.llm.base import BaseLlm
class FakeLLM(BaseLlm):
pass
self.app = App.from_config(config=self.config)
self.app.llm = FakeLLM() self.app.llm = FakeLLM()
if allow_reset: if self.allow_reset:
self.app.reset() self.app.reset()
def _sanitize_role(self, role: str) -> str: def _sanitize_role(self, role: str) -> str:
@@ -86,6 +87,8 @@ class RAGStorage(Storage):
return role.replace("\n", "").replace(" ", "_").replace("/", "_") return role.replace("\n", "").replace(" ", "_").replace("/", "_")
def save(self, value: Any, metadata: Dict[str, Any]) -> None: def save(self, value: Any, metadata: Dict[str, Any]) -> None:
if not hasattr(self, "app"):
self._initialize_app()
self._generate_embedding(value, metadata) self._generate_embedding(value, metadata)
def search( # type: ignore # BUG?: Signature of "search" incompatible with supertype "Storage" def search( # type: ignore # BUG?: Signature of "search" incompatible with supertype "Storage"
@@ -95,6 +98,10 @@ class RAGStorage(Storage):
filter: Optional[dict] = None, filter: Optional[dict] = None,
score_threshold: float = 0.35, score_threshold: float = 0.35,
) -> List[Any]: ) -> List[Any]:
if not hasattr(self, "app"):
self._initialize_app()
from embedchain.vectordb.chroma import InvalidDimensionException
with suppress_logging(): with suppress_logging():
try: try:
results = ( results = (
@@ -108,6 +115,10 @@ class RAGStorage(Storage):
return [r for r in results if r["metadata"]["score"] >= score_threshold] return [r for r in results if r["metadata"]["score"] >= score_threshold]
def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> Any: def _generate_embedding(self, text: str, metadata: Dict[str, Any]) -> Any:
if not hasattr(self, "app"):
self._initialize_app()
from embedchain.models.data_type import DataType
self.app.add(text, data_type=DataType.TEXT, metadata=metadata) self.app.add(text, data_type=DataType.TEXT, metadata=metadata)
def reset(self) -> None: def reset(self) -> None:

View File

@@ -21,9 +21,7 @@ with suppress_warnings():
from opentelemetry import trace # noqa: E402 from opentelemetry import trace # noqa: E402
from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter # noqa: E402
OTLPSpanExporter, # noqa: E402
)
from opentelemetry.sdk.resources import SERVICE_NAME, Resource # noqa: E402 from opentelemetry.sdk.resources import SERVICE_NAME, Resource # noqa: E402
from opentelemetry.sdk.trace import TracerProvider # noqa: E402 from opentelemetry.sdk.trace import TracerProvider # noqa: E402
from opentelemetry.sdk.trace.export import BatchSpanProcessor # noqa: E402 from opentelemetry.sdk.trace.export import BatchSpanProcessor # noqa: E402

View File

@@ -1,4 +1,3 @@
from langchain.tools import StructuredTool
from crewai.agents.agent_builder.utilities.base_agent_tool import BaseAgentTools from crewai.agents.agent_builder.utilities.base_agent_tool import BaseAgentTools
@@ -6,6 +5,8 @@ class AgentTools(BaseAgentTools):
"""Default tools around agent delegation""" """Default tools around agent delegation"""
def tools(self): def tools(self):
from langchain.tools import StructuredTool
coworkers = ", ".join([f"{agent.role}" for agent in self.agents]) coworkers = ", ".join([f"{agent.role}" for agent in self.agents])
tools = [ tools = [
StructuredTool.from_function( StructuredTool.from_function(

View File

@@ -1,4 +1,3 @@
from langchain.tools import StructuredTool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crewai.agents.cache import CacheHandler from crewai.agents.cache import CacheHandler
@@ -14,6 +13,8 @@ class CacheTools(BaseModel):
) )
def tool(self): def tool(self):
from langchain.tools import StructuredTool
return StructuredTool.from_function( return StructuredTool.from_function(
func=self.hit_cache, func=self.hit_cache,
name=self.name, name=self.name,

View File

@@ -1,8 +1,5 @@
from typing import Any, Optional, Type from typing import Any, Optional, Type
import instructor
from litellm import completion
class InternalInstructor: class InternalInstructor:
"""Class that wraps an agent llm with instructor.""" """Class that wraps an agent llm with instructor."""
@@ -28,6 +25,10 @@ class InternalInstructor:
if self.agent and not self.llm: if self.agent and not self.llm:
self.llm = self.agent.function_calling_llm or self.agent.llm self.llm = self.agent.function_calling_llm or self.agent.llm
# Lazy import
import instructor
from litellm import completion
self._client = instructor.from_litellm( self._client = instructor.from_litellm(
completion, completion,
mode=instructor.Mode.TOOLS, mode=instructor.Mode.TOOLS,

View File

@@ -1,4 +1,5 @@
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess