Fix static typing errors (#187)

Co-authored-by: João Moura <joaomdmoura@gmail.com>
Author: Guilherme Vieira
Date: 2024-01-29 19:52:14 -03:00 (committed by GitHub)
Parent: cd77981102
Commit: 29c31a2404
18 changed files with 135 additions and 87 deletions


@@ -1,6 +1,6 @@
 name: Lint
-on: [push, pull_request]
+on: [pull_request]
 jobs:
   lint:


@@ -1,6 +1,6 @@
 name: Run Tests
-on: [push, pull_request]
+on: [pull_request]
 permissions:
   contents: write

.github/workflows/type-checker.yml (new file, 30 additions)

@@ -0,0 +1,30 @@
+name: Run Type Checks
+on: [pull_request]
+permissions:
+  contents: write
+jobs:
+  type-checker:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v2
+      - name: Setup Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+      - name: Install Requirements
+        run: |
+          sudo apt-get update &&
+          pip install poetry &&
+          poetry lock &&
+          poetry install
+      - name: Run type checks
+        run: poetry run pyright


@@ -208,6 +208,11 @@ pre-commit install
 poetry run pytest
 ```
+### Running static type checks
+```bash
+poetry run pyright
+```
 ### Packaging
 ```bash
 poetry build
@@ -224,5 +229,3 @@ If you are interested on having access to it and hiring weekly hours with our te
 ## License
 CrewAI is released under the MIT License

poetry.lock (generated)

@@ -204,34 +204,12 @@ dev = ["freezegun (>=1.0,<2.0)", "pytest (>=6.0)", "pytest-cov"]
 [[package]]
 name = "black"
-version = "23.12.1"
+version = "24.1.0"
 description = "The uncompromising code formatter."
 optional = false
 python-versions = ">=3.8"
-files = [
-    {file = "black-23.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0aaf6041986767a5e0ce663c7a2f0e9eaf21e6ff87a5f95cbf3675bfd4c41d2"},
-    {file = "black-23.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c88b3711d12905b74206227109272673edce0cb29f27e1385f33b0163c414bba"},
-    {file = "black-23.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a920b569dc6b3472513ba6ddea21f440d4b4c699494d2e972a1753cdc25df7b0"},
-    {file = "black-23.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:3fa4be75ef2a6b96ea8d92b1587dd8cb3a35c7e3d51f0738ced0781c3aa3a5a3"},
-    {file = "black-23.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8d4df77958a622f9b5a4c96edb4b8c0034f8434032ab11077ec6c56ae9f384ba"},
-    {file = "black-23.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:602cfb1196dc692424c70b6507593a2b29aac0547c1be9a1d1365f0d964c353b"},
-    {file = "black-23.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c4352800f14be5b4864016882cdba10755bd50805c95f728011bcb47a4afd59"},
-    {file = "black-23.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:0808494f2b2df923ffc5723ed3c7b096bd76341f6213989759287611e9837d50"},
-    {file = "black-23.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:25e57fd232a6d6ff3f4478a6fd0580838e47c93c83eaf1ccc92d4faf27112c4e"},
-    {file = "black-23.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2d9e13db441c509a3763a7a3d9a49ccc1b4e974a47be4e08ade2a228876500ec"},
-    {file = "black-23.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1bd9c210f8b109b1762ec9fd36592fdd528485aadb3f5849b2740ef17e674e"},
-    {file = "black-23.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:ae76c22bde5cbb6bfd211ec343ded2163bba7883c7bc77f6b756a1049436fbb9"},
-    {file = "black-23.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1fa88a0f74e50e4487477bc0bb900c6781dbddfdfa32691e780bf854c3b4a47f"},
-    {file = "black-23.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a4d6a9668e45ad99d2f8ec70d5c8c04ef4f32f648ef39048d010b0689832ec6d"},
-    {file = "black-23.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b18fb2ae6c4bb63eebe5be6bd869ba2f14fd0259bda7d18a46b764d8fb86298a"},
-    {file = "black-23.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:c04b6d9d20e9c13f43eee8ea87d44156b8505ca8a3c878773f68b4e4812a421e"},
-    {file = "black-23.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e1b38b3135fd4c025c28c55ddfc236b05af657828a8a6abe5deec419a0b7055"},
-    {file = "black-23.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f0031eaa7b921db76decd73636ef3a12c942ed367d8c3841a0739412b260a54"},
-    {file = "black-23.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97e56155c6b737854e60a9ab1c598ff2533d57e7506d97af5481141671abf3ea"},
-    {file = "black-23.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:dd15245c8b68fe2b6bd0f32c1556509d11bb33aec9b5d0866dd8e2ed3dba09c2"},
-    {file = "black-23.12.1-py3-none-any.whl", hash = "sha256:78baad24af0f033958cad29731e27363183e140962595def56423e626f4bee3e"},
-    {file = "black-23.12.1.tar.gz", hash = "sha256:4ce3ef14ebe8d9509188014d96af1c456a910d5b5cbf434a09fef7e024b3d0d5"},
-]
+files = []
+develop = false

 [package.dependencies]
 click = ">=8.0.0"
@@ -248,6 +226,12 @@ d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"]
 jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
 uvloop = ["uvloop (>=0.15.2)"]
+
+[package.source]
+type = "git"
+url = "https://github.com/psf/black.git"
+reference = "stable"
+resolved_reference = "0e6e46b9eb45f5a22062fe84c2c2ff46bd0d738e"

 [[package]]
 name = "certifi"
 version = "2023.11.17"
@@ -1549,6 +1533,24 @@ pyyaml = "*"
 [package.extras]
 extra = ["pygments (>=2.12)"]
+
+[[package]]
+name = "pyright"
+version = "1.1.333"
+description = "Command line wrapper for pyright"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "pyright-1.1.333-py3-none-any.whl", hash = "sha256:f0a7b7b0cac11c396b17ef3cf6c8527aca1269edaf5cf8203eed7d6dd1ef52aa"},
+    {file = "pyright-1.1.333.tar.gz", hash = "sha256:1c49b0029048120c4378f3baf6c1dcbbfb221678bb69654fe773c514430ac53c"},
+]
+
+[package.dependencies]
+nodeenv = ">=1.6.0"
+
+[package.extras]
+all = ["twine (>=3.4.1)"]
+dev = ["twine (>=3.4.1)"]

 [[package]]
 name = "pytest"
 version = "7.4.4"
@@ -1639,6 +1641,7 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -2350,4 +2353,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<4.0"
-content-hash = "882fc14146aca5bf99c1c0512d74622f35c91a5b3d01d56828b564e4c554eaba"
+content-hash = "21b3fbe3c3dde7aab6f5a00eae25c97d296fda3ee095914adf2b4b789d36e19d"


@@ -18,13 +18,15 @@ Repository = "https://github.com/joaomdmoura/crewai"
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
 pydantic = "^2.4.2"
-langchain = "^0.1.0"
+langchain = "0.1.0"
 openai = "^1.7.1"
 langchain-openai = "^0.0.2"
+pyright = "1.1.333"
+black = {git = "https://github.com/psf/black.git", rev = "stable"}

 [tool.poetry.group.dev.dependencies]
 isort = "^5.13.2"
-black = "^23.12.1"
+black = "^24.1"
 autoflake = "^2.2.1"
 pre-commit = "^3.6.0"
 mkdocs-material = "^9.5.3"
@@ -33,6 +35,7 @@ mkdocs-material = "^9.5.3"
profile = "black" profile = "black"
known_first_party = ["crewai"] known_first_party = ["crewai"]
[tool.poetry.group.test.dependencies] [tool.poetry.group.test.dependencies]
pytest = "^7.4" pytest = "^7.4"
pytest-vcr = "^1.0.2" pytest-vcr = "^1.0.2"


@@ -2,10 +2,12 @@ import uuid
 from typing import Any, List, Optional

 from langchain.agents.format_scratchpad import format_log_to_str
+from langchain.agents.agent import RunnableAgent
 from langchain.memory import ConversationSummaryMemory
 from langchain.tools.render import render_text_description
 from langchain_core.runnables.config import RunnableConfig
 from langchain_openai import ChatOpenAI
+from langchain_core.language_models import BaseLanguageModel
 from pydantic import (
     UUID4,
     BaseModel,
@@ -47,7 +49,7 @@ class Agent(BaseModel):
             tools: Tools at agents disposal
     """

-    __hash__ = object.__hash__
+    __hash__ = object.__hash__  # type: ignore
     _logger: Logger = PrivateAttr()
     _rpm_controller: RPMController = PrivateAttr(default=None)
     _request_within_rpm_limit: Any = PrivateAttr(default=None)
@@ -80,21 +82,19 @@ class Agent(BaseModel):
     max_iter: Optional[int] = Field(
         default=15, description="Maximum iterations for an agent to execute a task"
     )
-    agent_executor: Optional[InstanceOf[CrewAgentExecutor]] = Field(
+    agent_executor: InstanceOf[CrewAgentExecutor] = Field(
         default=None, description="An instance of the CrewAgentExecutor class."
     )
-    tools_handler: Optional[InstanceOf[ToolsHandler]] = Field(
+    tools_handler: InstanceOf[ToolsHandler] = Field(
         default=None, description="An instance of the ToolsHandler class."
     )
-    cache_handler: Optional[InstanceOf[CacheHandler]] = Field(
+    cache_handler: InstanceOf[CacheHandler] = Field(
         default=CacheHandler(), description="An instance of the CacheHandler class."
     )
-    i18n: Optional[I18N] = Field(
-        default=I18N(), description="Internationalization settings."
-    )
-    llm: Optional[Any] = Field(
+    i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
+    llm: Any = Field(
         default_factory=lambda: ChatOpenAI(
-            model_name="gpt-4",
+            model="gpt-4",
         ),
         description="Language model that will run the agent.",
     )
@@ -140,6 +140,7 @@ class Agent(BaseModel):
         Returns:
             Output of the agent
         """
+
         if context:
             task = self.i18n.slice("task_with_context").format(
                 task=task, context=context
@@ -203,9 +204,9 @@ class Agent(BaseModel):
         }

         if self._rpm_controller:
-            executor_args[
-                "request_within_rpm_limit"
-            ] = self._rpm_controller.check_or_wait
+            executor_args["request_within_rpm_limit"] = (
+                self._rpm_controller.check_or_wait
+            )

         if self.memory:
             summary_memory = ConversationSummaryMemory(
@@ -234,7 +235,9 @@ class Agent(BaseModel):
                     i18n=self.i18n,
                 )
             )
-        self.agent_executor = CrewAgentExecutor(agent=inner_agent, **executor_args)
+        self.agent_executor = CrewAgentExecutor(
+            agent=RunnableAgent(runnable=inner_agent), **executor_args
+        )

     @staticmethod
     def __tools_names(tools) -> str:
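
The executor wiring above carries the main typing fix in this file: `AgentExecutor.agent` is annotated with langchain's agent base classes rather than a bare `Runnable`, so the raw LCEL chain has to be wrapped. A minimal sketch of the idea, assuming langchain 0.1.0; the `RunnableLambda` chain below is a hypothetical stand-in for CrewAI's real `prompt | llm | CrewAgentOutputParser` pipeline:

```python
# Hedged sketch: RunnableAgent adapts an LCEL runnable to the agent type that
# AgentExecutor's annotations expect (assumes langchain 0.1.0).
from langchain.agents.agent import BaseSingleActionAgent, RunnableAgent
from langchain_core.agents import AgentFinish
from langchain_core.runnables import RunnableLambda

# Stand-in runnable that finishes immediately instead of calling an LLM.
inner_agent = RunnableLambda(
    lambda inputs: AgentFinish(return_values={"output": "done"}, log="")
)

# Wrapping restores the type the executor is annotated with, which satisfies pyright.
typed_agent = RunnableAgent(runnable=inner_agent)
assert isinstance(typed_agent, BaseSingleActionAgent)
```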


@@ -1,12 +1,10 @@
 from typing import Optional

-from pydantic import PrivateAttr
-

 class CacheHandler:
     """Callback handler for tool usage."""

-    _cache: PrivateAttr = {}
+    _cache: dict = {}

     def __init__(self):
         self._cache = {}


@@ -108,8 +108,12 @@ class CrewAgentExecutor(AgentExecutor):
             if self._should_force_answer():
                 if isinstance(output, AgentAction):
                     output = output
-                else:
+                elif isinstance(output, CacheHit):
                     output = output.action
+                else:
+                    raise ValueError(
+                        f"Unexpected output type from agent: {type(output)}"
+                    )
                 yield self._force_answer(output)
                 return
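
The new `elif`/`else` ladder above is an exhaustiveness pattern: once every member of the `Union` returned by the parser is handled, pyright can narrow `output` in each branch instead of reporting a possible fall-through. A simplified illustration with stand-in classes (not the real langchain/CrewAI types):

```python
# Hedged sketch of isinstance narrowing over a Union, with simplified stand-in classes.
from typing import Union


class AgentAction: ...


class AgentFinish: ...


class CacheHit:
    action = AgentAction()


def forced_answer_target(output: Union[AgentAction, AgentFinish, CacheHit]) -> AgentAction:
    if isinstance(output, AgentAction):
        return output
    elif isinstance(output, CacheHit):
        # Narrowed to CacheHit, so the checker knows `.action` exists.
        return output.action
    else:
        # Exhaustive else: without it, the AgentFinish case would fall through and the
        # function would implicitly return None, contradicting the -> AgentAction hint.
        raise ValueError(f"Unexpected output type from agent: {type(output)}")
```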


@@ -50,7 +50,6 @@ class CrewAgentOutputParser(ReActSingleInputOutputParser):
     i18n: I18N

     def parse(self, text: str) -> Union[AgentAction, AgentFinish, CacheHit]:
-        FINAL_ANSWER_ACTION in text
         regex = (
             r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
         )


@@ -10,9 +10,9 @@ class ToolsHandler(BaseCallbackHandler):
"""Callback handler for tool usage.""" """Callback handler for tool usage."""
last_used_tool: Dict[str, Any] = {} last_used_tool: Dict[str, Any] = {}
cache: CacheHandler = None cache: CacheHandler
def __init__(self, cache: CacheHandler = None, **kwargs: Any): def __init__(self, cache: CacheHandler, **kwargs: Any):
"""Initialize the callback handler.""" """Initialize the callback handler."""
self.cache = cache self.cache = cache
super().__init__(**kwargs) super().__init__(**kwargs)
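
Dropping `= None` makes the cache a required dependency: under pyright, a default of `None` only type-checks if the annotation admits `None`. A small sketch of the two options, using simplified stand-ins rather than the real classes:

```python
# Hedged sketch: either require the dependency or widen the annotation to Optional.
from typing import Optional


class CacheHandler:
    """Simplified stand-in for crewai's CacheHandler."""

    def __init__(self) -> None:
        self._cache: dict = {}


class ToolsHandlerSketch:
    cache: CacheHandler  # required, as in the change above
    fallback: Optional[CacheHandler] = None  # the alternative: allow None explicitly

    def __init__(self, cache: CacheHandler) -> None:
        self.cache = cache


handler = ToolsHandlerSketch(cache=CacheHandler())
```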


@@ -38,12 +38,10 @@ class Crew(BaseModel):
         id: A unique identifier for the crew instance.
     """

-    __hash__ = object.__hash__
+    __hash__ = object.__hash__  # type: ignore
     _rpm_controller: RPMController = PrivateAttr()
     _logger: Logger = PrivateAttr()
-    _cache_handler: Optional[InstanceOf[CacheHandler]] = PrivateAttr(
-        default=CacheHandler()
-    )
+    _cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
     model_config = ConfigDict(arbitrary_types_allowed=True)
     tasks: List[Task] = Field(default_factory=list)
     agents: List[Agent] = Field(default_factory=list)
@@ -69,20 +67,20 @@ class Crew(BaseModel):
"may_not_set_field", "The 'id' field cannot be set by the user.", {} "may_not_set_field", "The 'id' field cannot be set by the user.", {}
) )
@classmethod
@field_validator("config", mode="before") @field_validator("config", mode="before")
@classmethod
def check_config_type( def check_config_type(
cls, v: Union[Json, Dict[str, Any]] cls, v: Union[Json, Dict[str, Any]]
) -> Union[Json, Dict[str, Any]]: ) -> Union[Json, Dict[str, Any]]:
"""Validates that the config is a valid type. """Validates that the config is a valid type.
Args: Args:
v: The config to be validated. v: The config to be validated.
Returns: Returns:
The config if it is valid. The config if it is valid.
""" """
return json.loads(v) if isinstance(v, Json) else v
# TODO: Improve typing
return json.loads(v) if isinstance(v, Json) else v # type: ignore
@model_validator(mode="after") @model_validator(mode="after")
def set_private_attrs(self) -> "Crew": def set_private_attrs(self) -> "Crew":
@@ -112,6 +110,8 @@ class Crew(BaseModel):
         return self

     def _setup_from_config(self):
+        assert self.config is not None, "Config should not be None."
+
         """Initializes agents and tasks from the provided config."""
         if not self.config.get("agents") or not self.config.get("tasks"):
             raise PydanticCustomError(
@@ -143,19 +143,24 @@ class Crew(BaseModel):
         if self.process == Process.sequential:
             return self._sequential_loop()
+        else:
+            raise NotImplementedError(
+                f"The process '{self.process}' is not implemented yet."
+            )

     def _sequential_loop(self) -> str:
         """Executes tasks sequentially and returns the final output."""
-        task_output = None
+        task_output = ""
         for task in self.tasks:
             self._prepare_and_execute_task(task)
             task_output = task.execute(task_output)
-            self._logger.log(
-                "debug", f"[{task.agent.role}] Task output: {task_output}\n\n"
-            )
+
+            role = task.agent.role if task.agent is not None else "None"
+            self._logger.log("debug", f"[{role}] Task output: {task_output}\n\n")

         if self.max_rpm:
             self._rpm_controller.stop_rpm_counter()

         return task_output

     def _prepare_and_execute_task(self, task: Task) -> None:
@@ -164,8 +169,9 @@ class Crew(BaseModel):
         Args:
             task: The task to be executed.
         """
-        if task.agent.allow_delegation:
+        if task.agent is not None and task.agent.allow_delegation:
             task.tools += AgentTools(agents=self.agents).tools()

-        self._logger.log("debug", f"Working Agent: {task.agent.role}")
+        role = task.agent.role if task.agent is not None else "None"
+        self._logger.log("debug", f"Working Agent: {role}")
         self._logger.log("info", f"Starting Task: {task.description}")
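
Both crew.py changes follow the two rules pyright enforces here: guard `Optional` values before attribute access, and give every code path a value of the declared return type. A standalone sketch (simplified, not the real Crew class):

```python
# Hedged sketch of the Optional-narrowing and exhaustive-return patterns used above.
from typing import Optional


class Agent:
    def __init__(self, role: str) -> None:
        self.role = role


def describe(agent: Optional[Agent]) -> str:
    # Without the None check, pyright flags `agent.role` as attribute access on None;
    # the ternary mirrors `role = task.agent.role if task.agent is not None else "None"`.
    return agent.role if agent is not None else "None"


def kickoff(process: str) -> str:
    if process == "sequential":
        return "sequential result"
    # The explicit raise keeps the function from implicitly returning None,
    # which would conflict with the `-> str` annotation.
    raise NotImplementedError(f"The process '{process}' is not implemented yet.")


print(describe(Agent("Researcher")), describe(None), kickoff("sequential"))
```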


@@ -12,7 +12,7 @@ from crewai.utilities import I18N
 class Task(BaseModel):
     """Class that represent a task to be executed."""

-    __hash__ = object.__hash__
+    __hash__ = object.__hash__  # type: ignore
     i18n: I18N = I18N()
     description: str = Field(description="Description of the actual task.")
     agent: Optional[Agent] = Field(
@@ -20,7 +20,7 @@ class Task(BaseModel):
     )
     tools: List[Any] = Field(
         default_factory=list,
-        description="Tools the agent are limited to use for this task.",
+        description="Tools the agent is limited to use for this task.",
     )
     expected_output: str = Field(
         description="Clear definition of expected output for the task.",
@@ -46,7 +46,7 @@ class Task(BaseModel):
@model_validator(mode="after") @model_validator(mode="after")
def check_tools(self): def check_tools(self):
"""Check if the tools are set.""" """Check if the tools are set."""
if not self.tools and (self.agent and self.agent.tools): if not self.tools and self.agent and self.agent.tools:
self.tools.extend(self.agent.tools) self.tools.extend(self.agent.tools)
return self return self


@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List

 from langchain.tools import Tool
 from pydantic import BaseModel, Field
@@ -11,9 +11,7 @@ class AgentTools(BaseModel):
"""Default tools around agent delegation""" """Default tools around agent delegation"""
agents: List[Agent] = Field(description="List of agents in this crew.") agents: List[Agent] = Field(description="List of agents in this crew.")
i18n: Optional[I18N] = Field( i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
default=I18N(), description="Internationalization settings."
)
def tools(self): def tools(self):
return [ return [


@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, PrivateAttr, ValidationError, model_valid
 class I18N(BaseModel):
-    _translations: Optional[Dict[str, str]] = PrivateAttr()
+    _translations: Dict[str, Dict[str, str]] = PrivateAttr()
     language: Optional[str] = Field(
         default="en",
         description="Language used to load translations",
@@ -25,10 +25,14 @@ class I18N(BaseModel):
                 self._translations = json.load(f)
         except FileNotFoundError:
             raise ValidationError(
-                f"Trasnlation file for language '{self.language}' not found."
+                f"Translation file for language '{self.language}' not found."
             )
         except json.JSONDecodeError:
             raise ValidationError(f"Error decoding JSON from the prompts file.")
+
+        if not self._translations:
+            self._translations = {}
+
         return self

     def slice(self, slice: str) -> str:
@@ -40,8 +44,8 @@ class I18N(BaseModel):
     def tools(self, error: str) -> str:
         return self.retrieve("tools", error)

-    def retrieve(self, kind, key):
+    def retrieve(self, kind, key) -> str:
         try:
-            return self._translations[kind].get(key)
+            return self._translations[kind][key]
         except:
             raise ValidationError(f"Translation for '{kind}':'{key}' not found.")
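
Switching from `.get(key)` to indexing is itself a typing fix: `dict.get` is typed as returning `Optional[str]`, which conflicts with the new `-> str` annotation, while plain indexing returns `str` and raises `KeyError` for missing keys, which the existing `except` already handles. A self-contained sketch of the same idea (simplified, using `ValueError` in place of the pydantic error):

```python
# Hedged sketch: indexing keeps the return type str; .get() would make it Optional[str].
from typing import Dict

translations: Dict[str, Dict[str, str]] = {"slices": {"task": "Do the task"}}


def retrieve(kind: str, key: str) -> str:
    try:
        return translations[kind][key]  # str; a missing key raises KeyError
    except KeyError:
        raise ValueError(f"Translation for '{kind}':'{key}' not found.")


print(retrieve("slices", "task"))
```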


@@ -1,6 +1,6 @@
 from typing import ClassVar

-from langchain.prompts import PromptTemplate
+from langchain.prompts import PromptTemplate, BasePromptTemplate
 from pydantic import BaseModel, Field

 from crewai.utilities import I18N
@@ -13,19 +13,19 @@ class Prompts(BaseModel):
     SCRATCHPAD_SLICE: ClassVar[str] = "\n{agent_scratchpad}"

-    def task_execution_with_memory(self) -> str:
+    def task_execution_with_memory(self) -> BasePromptTemplate:
         """Generate a prompt for task execution with memory components."""
         return self._build_prompt(["role_playing", "tools", "memory", "task"])

-    def task_execution_without_tools(self) -> str:
+    def task_execution_without_tools(self) -> BasePromptTemplate:
         """Generate a prompt for task execution without tools components."""
         return self._build_prompt(["role_playing", "task"])

-    def task_execution(self) -> str:
+    def task_execution(self) -> BasePromptTemplate:
         """Generate a standard prompt for task execution."""
         return self._build_prompt(["role_playing", "tools", "task"])

-    def _build_prompt(self, components: [str]) -> str:
+    def _build_prompt(self, components: list[str]) -> BasePromptTemplate:
         """Constructs a prompt string from specified components."""
         prompt_parts = [self.i18n.slice(component) for component in components]
         prompt_parts.append(self.SCRATCHPAD_SLICE)
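
The builders now advertise the type they actually produce: `PromptTemplate.from_template` returns a `PromptTemplate`, a `BasePromptTemplate` subclass, so the old `-> str` annotation did not match. A minimal sketch assuming langchain 0.1.0; the slice strings are made up, the real ones come from the I18N translations:

```python
# Hedged sketch: building a BasePromptTemplate from joined prompt slices.
from langchain.prompts import BasePromptTemplate, PromptTemplate


def build_prompt(components: list[str]) -> BasePromptTemplate:
    # Stand-in slices; the real class pulls these from its I18N instance.
    slices = {"role_playing": "You are {role}.", "task": "\nTask: {input}"}
    parts = [slices[component] for component in components]
    parts.append("\n{agent_scratchpad}")
    return PromptTemplate.from_template("".join(parts))


prompt = build_prompt(["role_playing", "task"])
print(type(prompt).__name__)  # PromptTemplate
```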


@@ -12,7 +12,7 @@ class RPMController(BaseModel):
     max_rpm: Union[int, None] = Field(default=None)
     logger: Logger = Field(default=None)
     _current_rpm: int = PrivateAttr(default=0)
-    _timer: threading.Timer = PrivateAttr(default=None)
+    _timer: threading.Timer | None = PrivateAttr(default=None)
     _lock: threading.Lock = PrivateAttr(default=None)

     @model_validator(mode="after")


@@ -1,6 +1,5 @@
"""Test Agent creation and execution basic functionality.""" """Test Agent creation and execution basic functionality."""
from crewai.agent import Agent from crewai.agent import Agent
from crewai.task import Task from crewai.task import Task
@@ -51,7 +50,6 @@ def test_task_tool_takes_precedence_ove_agent_tools():
description="Give me a list of 5 interesting ideas to explore for na article, what makes them unique and interesting.", description="Give me a list of 5 interesting ideas to explore for na article, what makes them unique and interesting.",
agent=researcher, agent=researcher,
tools=[fake_task_tool], tools=[fake_task_tool],
allow_delegation=False,
) )
assert task.tools == [fake_task_tool] assert task.tools == [fake_task_tool]
@@ -69,7 +67,6 @@ def test_task_prompt_includes_expected_output():
description="Give me a list of 5 interesting ideas to explore for na article, what makes them unique and interesting.", description="Give me a list of 5 interesting ideas to explore for na article, what makes them unique and interesting.",
expected_output="Bullet point list of 5 interesting ideas.", expected_output="Bullet point list of 5 interesting ideas.",
agent=researcher, agent=researcher,
allow_delegation=False,
) )
from unittest.mock import patch from unittest.mock import patch