mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
Adding Autocomplete to OSS (#1198)
* Cleaned up model_config * Fix pydantic issues * 99% done with autocomplete * fixed test issues * Fix type checking issues
This commit is contained in:
committed by
GitHub
parent
c511f4d0b5
commit
678dfffb62
@@ -113,10 +113,11 @@ class Agent(BaseAgent):
|
|||||||
description="Maximum number of retries for an agent to execute a task when an error occurs.",
|
description="Maximum number of retries for an agent to execute a task when an error occurs.",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(__pydantic_self__, **data):
|
@model_validator(mode="after")
|
||||||
config = data.pop("config", {})
|
def set_agent_ops_agent_name(self) -> "Agent":
|
||||||
super().__init__(**config, **data)
|
"""Set agent ops agent name."""
|
||||||
__pydantic_self__.agent_ops_agent_name = __pydantic_self__.role
|
self.agent_ops_agent_name = self.role
|
||||||
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def set_agent_executor(self) -> "Agent":
|
def set_agent_executor(self) -> "Agent":
|
||||||
@@ -213,7 +214,7 @@ class Agent(BaseAgent):
|
|||||||
raise e
|
raise e
|
||||||
result = self.execute_task(task, context, tools)
|
result = self.execute_task(task, context, tools)
|
||||||
|
|
||||||
if self.max_rpm:
|
if self.max_rpm and self._rpm_controller:
|
||||||
self._rpm_controller.stop_rpm_counter()
|
self._rpm_controller.stop_rpm_counter()
|
||||||
|
|
||||||
# If there was any tool in self.tools_results that had result_as_answer
|
# If there was any tool in self.tools_results that had result_as_answer
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Any, Dict, List, Optional, TypeVar
|
|||||||
from pydantic import (
|
from pydantic import (
|
||||||
UUID4,
|
UUID4,
|
||||||
BaseModel,
|
BaseModel,
|
||||||
ConfigDict,
|
|
||||||
Field,
|
Field,
|
||||||
InstanceOf,
|
InstanceOf,
|
||||||
PrivateAttr,
|
PrivateAttr,
|
||||||
@@ -74,12 +73,17 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
__hash__ = object.__hash__ # type: ignore
|
__hash__ = object.__hash__ # type: ignore
|
||||||
_logger: Logger = PrivateAttr()
|
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
|
||||||
_rpm_controller: RPMController = PrivateAttr(default=None)
|
_rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
|
||||||
_request_within_rpm_limit: Any = PrivateAttr(default=None)
|
_request_within_rpm_limit: Any = PrivateAttr(default=None)
|
||||||
formatting_errors: int = 0
|
_original_role: Optional[str] = PrivateAttr(default=None)
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
_original_goal: Optional[str] = PrivateAttr(default=None)
|
||||||
|
_original_backstory: Optional[str] = PrivateAttr(default=None)
|
||||||
|
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
||||||
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
||||||
|
formatting_errors: int = Field(
|
||||||
|
default=0, description="Number of formatting errors."
|
||||||
|
)
|
||||||
role: str = Field(description="Role of the agent")
|
role: str = Field(description="Role of the agent")
|
||||||
goal: str = Field(description="Objective of the agent")
|
goal: str = Field(description="Objective of the agent")
|
||||||
backstory: str = Field(description="Backstory of the agent")
|
backstory: str = Field(description="Backstory of the agent")
|
||||||
@@ -123,15 +127,6 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
default=None, description="Maximum number of tokens for the agent's execution."
|
default=None, description="Maximum number of tokens for the agent's execution."
|
||||||
)
|
)
|
||||||
|
|
||||||
_original_role: str | None = None
|
|
||||||
_original_goal: str | None = None
|
|
||||||
_original_backstory: str | None = None
|
|
||||||
_token_process: TokenProcess = TokenProcess()
|
|
||||||
|
|
||||||
def __init__(__pydantic_self__, **data):
|
|
||||||
config = data.pop("config", {})
|
|
||||||
super().__init__(**config, **data)
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def set_config_attributes(self):
|
def set_config_attributes(self):
|
||||||
if self.config:
|
if self.config:
|
||||||
|
|||||||
11
src/crewai/agents/cache/cache_handler.py
vendored
11
src/crewai/agents/cache/cache_handler.py
vendored
@@ -1,13 +1,12 @@
|
|||||||
from typing import Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, PrivateAttr
|
||||||
|
|
||||||
|
|
||||||
class CacheHandler:
|
class CacheHandler(BaseModel):
|
||||||
"""Callback handler for tool usage."""
|
"""Callback handler for tool usage."""
|
||||||
|
|
||||||
_cache: dict = {}
|
_cache: Dict[str, Any] = PrivateAttr(default_factory=dict)
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._cache = {}
|
|
||||||
|
|
||||||
def add(self, tool, input, output):
|
def add(self, tool, input, output):
|
||||||
self._cache[f"{tool}-{input}"] = output
|
self._cache[f"{tool}-{input}"] = output
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from langchain_core.callbacks import BaseCallbackHandler
|
|||||||
from pydantic import (
|
from pydantic import (
|
||||||
UUID4,
|
UUID4,
|
||||||
BaseModel,
|
BaseModel,
|
||||||
ConfigDict,
|
|
||||||
Field,
|
Field,
|
||||||
InstanceOf,
|
InstanceOf,
|
||||||
Json,
|
Json,
|
||||||
@@ -105,7 +104,6 @@ class Crew(BaseModel):
|
|||||||
|
|
||||||
name: Optional[str] = Field(default=None)
|
name: Optional[str] = Field(default=None)
|
||||||
cache: bool = Field(default=True)
|
cache: bool = Field(default=True)
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
||||||
tasks: List[Task] = Field(default_factory=list)
|
tasks: List[Task] = Field(default_factory=list)
|
||||||
agents: List[BaseAgent] = Field(default_factory=list)
|
agents: List[BaseAgent] = Field(default_factory=list)
|
||||||
process: Process = Field(default=Process.sequential)
|
process: Process = Field(default=Process.sequential)
|
||||||
|
|||||||
@@ -4,14 +4,12 @@ from typing import Any, Callable, Dict
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from pydantic import ConfigDict
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
def CrewBase(cls):
|
def CrewBase(cls):
|
||||||
class WrappedClass(cls):
|
class WrappedClass(cls):
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
||||||
is_crew_class: bool = True # type: ignore
|
is_crew_class: bool = True # type: ignore
|
||||||
|
|
||||||
# Get the directory of the class being decorated
|
# Get the directory of the class being decorated
|
||||||
|
|||||||
@@ -1,24 +1,24 @@
|
|||||||
from typing import Callable, Dict
|
from typing import Any, Callable, Dict, List, Type, Union
|
||||||
|
|
||||||
from pydantic import ConfigDict
|
|
||||||
|
|
||||||
from crewai.crew import Crew
|
from crewai.crew import Crew
|
||||||
from crewai.pipeline.pipeline import Pipeline
|
from crewai.pipeline.pipeline import Pipeline
|
||||||
from crewai.routers.router import Router
|
from crewai.routers.router import Router
|
||||||
|
|
||||||
|
PipelineStage = Union[Crew, List[Crew], Router]
|
||||||
|
|
||||||
|
|
||||||
# TODO: Could potentially remove. Need to check with @joao and @gui if this is needed for CrewAI+
|
# TODO: Could potentially remove. Need to check with @joao and @gui if this is needed for CrewAI+
|
||||||
def PipelineBase(cls):
|
def PipelineBase(cls: Type[Any]) -> Type[Any]:
|
||||||
class WrappedClass(cls):
|
class WrappedClass(cls):
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
||||||
is_pipeline_class: bool = True # type: ignore
|
is_pipeline_class: bool = True # type: ignore
|
||||||
|
stages: List[PipelineStage]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.stages = []
|
self.stages = []
|
||||||
self._map_pipeline_components()
|
self._map_pipeline_components()
|
||||||
|
|
||||||
def _get_all_functions(self):
|
def _get_all_functions(self) -> Dict[str, Callable[..., Any]]:
|
||||||
return {
|
return {
|
||||||
name: getattr(self, name)
|
name: getattr(self, name)
|
||||||
for name in dir(self)
|
for name in dir(self)
|
||||||
@@ -26,15 +26,15 @@ def PipelineBase(cls):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def _filter_functions(
|
def _filter_functions(
|
||||||
self, functions: Dict[str, Callable], attribute: str
|
self, functions: Dict[str, Callable[..., Any]], attribute: str
|
||||||
) -> Dict[str, Callable]:
|
) -> Dict[str, Callable[..., Any]]:
|
||||||
return {
|
return {
|
||||||
name: func
|
name: func
|
||||||
for name, func in functions.items()
|
for name, func in functions.items()
|
||||||
if hasattr(func, attribute)
|
if hasattr(func, attribute)
|
||||||
}
|
}
|
||||||
|
|
||||||
def _map_pipeline_components(self):
|
def _map_pipeline_components(self) -> None:
|
||||||
all_functions = self._get_all_functions()
|
all_functions = self._get_all_functions()
|
||||||
crew_functions = self._filter_functions(all_functions, "is_crew")
|
crew_functions = self._filter_functions(all_functions, "is_crew")
|
||||||
router_functions = self._filter_functions(all_functions, "is_router")
|
router_functions = self._filter_functions(all_functions, "is_router")
|
||||||
|
|||||||
@@ -1,32 +1,26 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Callable, Dict, Generic, Tuple, TypeVar
|
from typing import Any, Callable, Dict, Tuple
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, PrivateAttr
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
|
|
||||||
T = TypeVar("T", bound=Dict[str, Any])
|
|
||||||
U = TypeVar("U")
|
class Route(BaseModel):
|
||||||
|
condition: Callable[[Dict[str, Any]], bool]
|
||||||
|
pipeline: Any
|
||||||
|
|
||||||
|
|
||||||
class Route(Generic[T, U]):
|
class Router(BaseModel):
|
||||||
condition: Callable[[T], bool]
|
routes: Dict[str, Route] = Field(
|
||||||
pipeline: U
|
|
||||||
|
|
||||||
def __init__(self, condition: Callable[[T], bool], pipeline: U):
|
|
||||||
self.condition = condition
|
|
||||||
self.pipeline = pipeline
|
|
||||||
|
|
||||||
|
|
||||||
class Router(BaseModel, Generic[T, U]):
|
|
||||||
routes: Dict[str, Route[T, U]] = Field(
|
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="Dictionary of route names to (condition, pipeline) tuples",
|
description="Dictionary of route names to (condition, pipeline) tuples",
|
||||||
)
|
)
|
||||||
default: U = Field(..., description="Default pipeline if no conditions are met")
|
default: Any = Field(..., description="Default pipeline if no conditions are met")
|
||||||
_route_types: Dict[str, type] = PrivateAttr(default_factory=dict)
|
_route_types: Dict[str, type] = PrivateAttr(default_factory=dict)
|
||||||
|
|
||||||
model_config = {"arbitrary_types_allowed": True}
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
def __init__(self, routes: Dict[str, Route[T, U]], default: U, **data):
|
def __init__(self, routes: Dict[str, Route], default: Any, **data):
|
||||||
super().__init__(routes=routes, default=default, **data)
|
super().__init__(routes=routes, default=default, **data)
|
||||||
self._check_copyable(default)
|
self._check_copyable(default)
|
||||||
for name, route in routes.items():
|
for name, route in routes.items():
|
||||||
@@ -34,16 +28,16 @@ class Router(BaseModel, Generic[T, U]):
|
|||||||
self._route_types[name] = type(route.pipeline)
|
self._route_types[name] = type(route.pipeline)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_copyable(obj):
|
def _check_copyable(obj: Any) -> None:
|
||||||
if not hasattr(obj, "copy") or not callable(getattr(obj, "copy")):
|
if not hasattr(obj, "copy") or not callable(getattr(obj, "copy")):
|
||||||
raise ValueError(f"Object of type {type(obj)} must have a 'copy' method")
|
raise ValueError(f"Object of type {type(obj)} must have a 'copy' method")
|
||||||
|
|
||||||
def add_route(
|
def add_route(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
condition: Callable[[T], bool],
|
condition: Callable[[Dict[str, Any]], bool],
|
||||||
pipeline: U,
|
pipeline: Any,
|
||||||
) -> "Router[T, U]":
|
) -> "Router":
|
||||||
"""
|
"""
|
||||||
Add a named route with its condition and corresponding pipeline to the router.
|
Add a named route with its condition and corresponding pipeline to the router.
|
||||||
|
|
||||||
@@ -60,7 +54,7 @@ class Router(BaseModel, Generic[T, U]):
|
|||||||
self._route_types[name] = type(pipeline)
|
self._route_types[name] = type(pipeline)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def route(self, input_data: T) -> Tuple[U, str]:
|
def route(self, input_data: Dict[str, Any]) -> Tuple[Any, str]:
|
||||||
"""
|
"""
|
||||||
Evaluate the input against the conditions and return the appropriate pipeline.
|
Evaluate the input against the conditions and return the appropriate pipeline.
|
||||||
|
|
||||||
@@ -76,15 +70,15 @@ class Router(BaseModel, Generic[T, U]):
|
|||||||
|
|
||||||
return self.default, "default"
|
return self.default, "default"
|
||||||
|
|
||||||
def copy(self) -> "Router[T, U]":
|
def copy(self) -> "Router":
|
||||||
"""Create a deep copy of the Router."""
|
"""Create a deep copy of the Router."""
|
||||||
new_routes = {
|
new_routes = {
|
||||||
name: Route(
|
name: Route(
|
||||||
condition=deepcopy(route.condition),
|
condition=deepcopy(route.condition),
|
||||||
pipeline=route.pipeline.copy(), # type: ignore
|
pipeline=route.pipeline.copy(),
|
||||||
)
|
)
|
||||||
for name, route in self.routes.items()
|
for name, route in self.routes.items()
|
||||||
}
|
}
|
||||||
new_default = self.default.copy() # type: ignore
|
new_default = self.default.copy()
|
||||||
|
|
||||||
return Router(routes=new_routes, default=new_default)
|
return Router(routes=new_routes, default=new_default)
|
||||||
|
|||||||
@@ -9,7 +9,14 @@ from hashlib import md5
|
|||||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
from opentelemetry.trace import Span
|
from opentelemetry.trace import Span
|
||||||
from pydantic import UUID4, BaseModel, Field, field_validator, model_validator
|
from pydantic import (
|
||||||
|
UUID4,
|
||||||
|
BaseModel,
|
||||||
|
Field,
|
||||||
|
PrivateAttr,
|
||||||
|
field_validator,
|
||||||
|
model_validator,
|
||||||
|
)
|
||||||
from pydantic_core import PydanticCustomError
|
from pydantic_core import PydanticCustomError
|
||||||
|
|
||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
@@ -39,9 +46,6 @@ class Task(BaseModel):
|
|||||||
tools: List of tools/resources limited for task execution.
|
tools: List of tools/resources limited for task execution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
__hash__ = object.__hash__ # type: ignore
|
__hash__ = object.__hash__ # type: ignore
|
||||||
used_tools: int = 0
|
used_tools: int = 0
|
||||||
tools_errors: int = 0
|
tools_errors: int = 0
|
||||||
@@ -104,16 +108,12 @@ class Task(BaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
_telemetry: Telemetry
|
_telemetry: Telemetry = PrivateAttr(default_factory=Telemetry)
|
||||||
_execution_span: Span | None = None
|
_execution_span: Optional[Span] = PrivateAttr(default=None)
|
||||||
_original_description: str | None = None
|
_original_description: Optional[str] = PrivateAttr(default=None)
|
||||||
_original_expected_output: str | None = None
|
_original_expected_output: Optional[str] = PrivateAttr(default=None)
|
||||||
_thread: threading.Thread | None = None
|
_thread: Optional[threading.Thread] = PrivateAttr(default=None)
|
||||||
_execution_time: float | None = None
|
_execution_time: Optional[float] = PrivateAttr(default=None)
|
||||||
|
|
||||||
def __init__(__pydantic_self__, **data):
|
|
||||||
config = data.pop("config", {})
|
|
||||||
super().__init__(**config, **data)
|
|
||||||
|
|
||||||
@field_validator("id", mode="before")
|
@field_validator("id", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -137,12 +137,6 @@ class Task(BaseModel):
|
|||||||
return value[1:]
|
return value[1:]
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def set_private_attrs(self) -> "Task":
|
|
||||||
"""Set private attributes."""
|
|
||||||
self._telemetry = Telemetry()
|
|
||||||
return self
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def set_attributes_based_on_config(self) -> "Task":
|
def set_attributes_based_on_config(self) -> "Task":
|
||||||
"""Set attributes based on the agent configuration."""
|
"""Set attributes based on the agent configuration."""
|
||||||
@@ -263,9 +257,7 @@ class Task(BaseModel):
|
|||||||
content = (
|
content = (
|
||||||
json_output
|
json_output
|
||||||
if json_output
|
if json_output
|
||||||
else pydantic_output.model_dump_json()
|
else pydantic_output.model_dump_json() if pydantic_output else result
|
||||||
if pydantic_output
|
|
||||||
else result
|
|
||||||
)
|
)
|
||||||
self._save_file(content)
|
self._save_file(content)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from langchain.tools import StructuredTool
|
from langchain.tools import StructuredTool
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai.agents.cache import CacheHandler
|
from crewai.agents.cache import CacheHandler
|
||||||
|
|
||||||
@@ -7,11 +7,10 @@ from crewai.agents.cache import CacheHandler
|
|||||||
class CacheTools(BaseModel):
|
class CacheTools(BaseModel):
|
||||||
"""Default tools to hit the cache."""
|
"""Default tools to hit the cache."""
|
||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
||||||
name: str = "Hit Cache"
|
name: str = "Hit Cache"
|
||||||
cache_handler: CacheHandler = Field(
|
cache_handler: CacheHandler = Field(
|
||||||
description="Cache Handler for the crew",
|
description="Cache Handler for the crew",
|
||||||
default=CacheHandler(),
|
default_factory=CacheHandler,
|
||||||
)
|
)
|
||||||
|
|
||||||
def tool(self):
|
def tool(self):
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
|
|
||||||
from crewai.utilities.printer import Printer
|
from crewai.utilities.printer import Printer
|
||||||
|
|
||||||
|
|
||||||
class Logger:
|
class Logger(BaseModel):
|
||||||
_printer = Printer()
|
verbose: bool = Field(default=False)
|
||||||
|
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||||
def __init__(self, verbose=False):
|
|
||||||
self.verbose = verbose
|
|
||||||
|
|
||||||
def log(self, level, message, color="bold_green"):
|
def log(self, level, message, color="bold_green"):
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
|||||||
@@ -1,44 +1,50 @@
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from typing import Union
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
from pydantic import BaseModel, Field, PrivateAttr, model_validator
|
||||||
|
|
||||||
from crewai.utilities.logger import Logger
|
from crewai.utilities.logger import Logger
|
||||||
|
|
||||||
|
|
||||||
class RPMController(BaseModel):
|
class RPMController(BaseModel):
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
max_rpm: Optional[int] = Field(default=None)
|
||||||
max_rpm: Union[int, None] = Field(default=None)
|
logger: Logger = Field(default_factory=lambda: Logger(verbose=False))
|
||||||
logger: Logger = Field(default=None)
|
|
||||||
_current_rpm: int = PrivateAttr(default=0)
|
_current_rpm: int = PrivateAttr(default=0)
|
||||||
_timer: threading.Timer | None = PrivateAttr(default=None)
|
_timer: Optional[threading.Timer] = PrivateAttr(default=None)
|
||||||
_lock: threading.Lock = PrivateAttr(default=None)
|
_lock: Optional[threading.Lock] = PrivateAttr(default=None)
|
||||||
_shutdown_flag = False
|
_shutdown_flag: bool = PrivateAttr(default=False)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def reset_counter(self):
|
def reset_counter(self):
|
||||||
if self.max_rpm:
|
if self.max_rpm is not None:
|
||||||
if not self._shutdown_flag:
|
if not self._shutdown_flag:
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
self._reset_request_count()
|
self._reset_request_count()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def check_or_wait(self):
|
def check_or_wait(self):
|
||||||
if not self.max_rpm:
|
if self.max_rpm is None:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
with self._lock:
|
def _check_and_increment():
|
||||||
if self._current_rpm < self.max_rpm:
|
if self.max_rpm is not None and self._current_rpm < self.max_rpm:
|
||||||
self._current_rpm += 1
|
self._current_rpm += 1
|
||||||
return True
|
return True
|
||||||
else:
|
elif self.max_rpm is not None:
|
||||||
self.logger.log(
|
self.logger.log(
|
||||||
"info", "Max RPM reached, waiting for next minute to start."
|
"info", "Max RPM reached, waiting for next minute to start."
|
||||||
)
|
)
|
||||||
self._wait_for_next_minute()
|
self._wait_for_next_minute()
|
||||||
self._current_rpm = 1
|
self._current_rpm = 1
|
||||||
return True
|
return True
|
||||||
|
return True
|
||||||
|
|
||||||
|
if self._lock:
|
||||||
|
with self._lock:
|
||||||
|
return _check_and_increment()
|
||||||
|
else:
|
||||||
|
return _check_and_increment()
|
||||||
|
|
||||||
def stop_rpm_counter(self):
|
def stop_rpm_counter(self):
|
||||||
if self._timer:
|
if self._timer:
|
||||||
@@ -50,10 +56,18 @@ class RPMController(BaseModel):
|
|||||||
self._current_rpm = 0
|
self._current_rpm = 0
|
||||||
|
|
||||||
def _reset_request_count(self):
|
def _reset_request_count(self):
|
||||||
with self._lock:
|
def _reset():
|
||||||
self._current_rpm = 0
|
self._current_rpm = 0
|
||||||
|
if not self._shutdown_flag:
|
||||||
|
self._timer = threading.Timer(60.0, self._reset_request_count)
|
||||||
|
self._timer.start()
|
||||||
|
|
||||||
|
if self._lock:
|
||||||
|
with self._lock:
|
||||||
|
_reset()
|
||||||
|
else:
|
||||||
|
_reset()
|
||||||
|
|
||||||
if self._timer:
|
if self._timer:
|
||||||
self._shutdown_flag = True
|
self._shutdown_flag = True
|
||||||
self._timer.cancel()
|
self._timer.cancel()
|
||||||
self._timer = threading.Timer(60.0, self._reset_request_count)
|
|
||||||
self._timer.start()
|
|
||||||
|
|||||||
@@ -4,11 +4,6 @@ from unittest import mock
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain.tools import tool
|
|
||||||
from langchain_core.exceptions import OutputParserException
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
from langchain.schema import AgentAction
|
|
||||||
|
|
||||||
from crewai import Agent, Crew, Task
|
from crewai import Agent, Crew, Task
|
||||||
from crewai.agents.cache import CacheHandler
|
from crewai.agents.cache import CacheHandler
|
||||||
from crewai.agents.executor import CrewAgentExecutor
|
from crewai.agents.executor import CrewAgentExecutor
|
||||||
@@ -16,6 +11,10 @@ from crewai.agents.parser import CrewAgentParser
|
|||||||
from crewai.tools.tool_calling import InstructorToolCalling
|
from crewai.tools.tool_calling import InstructorToolCalling
|
||||||
from crewai.tools.tool_usage import ToolUsage
|
from crewai.tools.tool_usage import ToolUsage
|
||||||
from crewai.utilities import RPMController
|
from crewai.utilities import RPMController
|
||||||
|
from langchain.schema import AgentAction
|
||||||
|
from langchain.tools import tool
|
||||||
|
from langchain_core.exceptions import OutputParserException
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
|
||||||
def test_agent_creation():
|
def test_agent_creation():
|
||||||
@@ -817,7 +816,7 @@ def test_agent_definition_based_on_dict():
|
|||||||
"verbose": True,
|
"verbose": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
agent = Agent(config=config)
|
agent = Agent(**config)
|
||||||
|
|
||||||
assert agent.role == "test role"
|
assert agent.role == "test role"
|
||||||
assert agent.goal == "test goal"
|
assert agent.goal == "test goal"
|
||||||
@@ -837,7 +836,7 @@ def test_agent_human_input():
|
|||||||
"backstory": "test backstory",
|
"backstory": "test backstory",
|
||||||
}
|
}
|
||||||
|
|
||||||
agent = Agent(config=config)
|
agent = Agent(**config)
|
||||||
|
|
||||||
task = Task(
|
task = Task(
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import pydantic_core
|
import pydantic_core
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from crewai.agent import Agent
|
from crewai.agent import Agent
|
||||||
from crewai.agents.cache import CacheHandler
|
from crewai.agents.cache import CacheHandler
|
||||||
from crewai.crew import Crew
|
from crewai.crew import Crew
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""Test Agent creation and execution basic functionality."""
|
"""Test Agent creation and execution basic functionality."""
|
||||||
|
|
||||||
import os
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -703,7 +703,7 @@ def test_task_definition_based_on_dict():
|
|||||||
"expected_output": "The score of the title.",
|
"expected_output": "The score of the title.",
|
||||||
}
|
}
|
||||||
|
|
||||||
task = Task(config=config)
|
task = Task(**config)
|
||||||
|
|
||||||
assert task.description == config["description"]
|
assert task.description == config["description"]
|
||||||
assert task.expected_output == config["expected_output"]
|
assert task.expected_output == config["expected_output"]
|
||||||
@@ -716,7 +716,7 @@ def test_conditional_task_definition_based_on_dict():
|
|||||||
"expected_output": "The score of the title.",
|
"expected_output": "The score of the title.",
|
||||||
}
|
}
|
||||||
|
|
||||||
task = ConditionalTask(config=config, condition=lambda x: True)
|
task = ConditionalTask(**config, condition=lambda x: True)
|
||||||
|
|
||||||
assert task.description == config["description"]
|
assert task.description == config["description"]
|
||||||
assert task.expected_output == config["expected_output"]
|
assert task.expected_output == config["expected_output"]
|
||||||
|
|||||||
Reference in New Issue
Block a user