Fix static typing errors (#187)

Co-authored-by: João Moura <joaomdmoura@gmail.com>
This commit is contained in:
Guilherme Vieira
2024-01-29 19:52:14 -03:00
committed by GitHub
parent 66d66bddae
commit e0d97b9916
18 changed files with 135 additions and 87 deletions

View File

@@ -1,10 +1,10 @@
name: Lint
on: [push, pull_request]
on: [pull_request]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: psf/black@stable
- uses: psf/black@stable

View File

@@ -1,6 +1,6 @@
name: Run Tests
on: [push, pull_request]
on: [pull_request]
permissions:
contents: write

30
.github/workflows/type-checker.yml vendored Normal file
View File

@@ -0,0 +1,30 @@
name: Run Type Checks
on: [pull_request]
permissions:
contents: write
jobs:
type-checker:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install Requirements
run: |
sudo apt-get update &&
pip install poetry &&
poetry lock &&
poetry install
- name: Run type checks
run: poetry run pyright

View File

@@ -208,6 +208,11 @@ pre-commit install
poetry run pytest
```
### Running static type checks
```bash
poetry run pyright
```
### Packaging
```bash
poetry build
@@ -224,5 +229,3 @@ If you are interested on having access to it and hiring weekly hours with our te
## License
CrewAI is released under the MIT License

55
poetry.lock generated
View File

@@ -204,34 +204,12 @@ dev = ["freezegun (>=1.0,<2.0)", "pytest (>=6.0)", "pytest-cov"]
[[package]]
name = "black"
version = "23.12.1"
version = "24.1.0"
description = "The uncompromising code formatter."
optional = false
python-versions = ">=3.8"
files = [
{file = "black-23.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0aaf6041986767a5e0ce663c7a2f0e9eaf21e6ff87a5f95cbf3675bfd4c41d2"},
{file = "black-23.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c88b3711d12905b74206227109272673edce0cb29f27e1385f33b0163c414bba"},
{file = "black-23.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a920b569dc6b3472513ba6ddea21f440d4b4c699494d2e972a1753cdc25df7b0"},
{file = "black-23.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:3fa4be75ef2a6b96ea8d92b1587dd8cb3a35c7e3d51f0738ced0781c3aa3a5a3"},
{file = "black-23.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8d4df77958a622f9b5a4c96edb4b8c0034f8434032ab11077ec6c56ae9f384ba"},
{file = "black-23.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:602cfb1196dc692424c70b6507593a2b29aac0547c1be9a1d1365f0d964c353b"},
{file = "black-23.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c4352800f14be5b4864016882cdba10755bd50805c95f728011bcb47a4afd59"},
{file = "black-23.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:0808494f2b2df923ffc5723ed3c7b096bd76341f6213989759287611e9837d50"},
{file = "black-23.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:25e57fd232a6d6ff3f4478a6fd0580838e47c93c83eaf1ccc92d4faf27112c4e"},
{file = "black-23.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2d9e13db441c509a3763a7a3d9a49ccc1b4e974a47be4e08ade2a228876500ec"},
{file = "black-23.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1bd9c210f8b109b1762ec9fd36592fdd528485aadb3f5849b2740ef17e674e"},
{file = "black-23.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:ae76c22bde5cbb6bfd211ec343ded2163bba7883c7bc77f6b756a1049436fbb9"},
{file = "black-23.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1fa88a0f74e50e4487477bc0bb900c6781dbddfdfa32691e780bf854c3b4a47f"},
{file = "black-23.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a4d6a9668e45ad99d2f8ec70d5c8c04ef4f32f648ef39048d010b0689832ec6d"},
{file = "black-23.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b18fb2ae6c4bb63eebe5be6bd869ba2f14fd0259bda7d18a46b764d8fb86298a"},
{file = "black-23.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:c04b6d9d20e9c13f43eee8ea87d44156b8505ca8a3c878773f68b4e4812a421e"},
{file = "black-23.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e1b38b3135fd4c025c28c55ddfc236b05af657828a8a6abe5deec419a0b7055"},
{file = "black-23.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f0031eaa7b921db76decd73636ef3a12c942ed367d8c3841a0739412b260a54"},
{file = "black-23.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97e56155c6b737854e60a9ab1c598ff2533d57e7506d97af5481141671abf3ea"},
{file = "black-23.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:dd15245c8b68fe2b6bd0f32c1556509d11bb33aec9b5d0866dd8e2ed3dba09c2"},
{file = "black-23.12.1-py3-none-any.whl", hash = "sha256:78baad24af0f033958cad29731e27363183e140962595def56423e626f4bee3e"},
{file = "black-23.12.1.tar.gz", hash = "sha256:4ce3ef14ebe8d9509188014d96af1c456a910d5b5cbf434a09fef7e024b3d0d5"},
]
files = []
develop = false
[package.dependencies]
click = ">=8.0.0"
@@ -248,6 +226,12 @@ d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"]
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
uvloop = ["uvloop (>=0.15.2)"]
[package.source]
type = "git"
url = "https://github.com/psf/black.git"
reference = "stable"
resolved_reference = "0e6e46b9eb45f5a22062fe84c2c2ff46bd0d738e"
[[package]]
name = "certifi"
version = "2023.11.17"
@@ -1549,6 +1533,24 @@ pyyaml = "*"
[package.extras]
extra = ["pygments (>=2.12)"]
[[package]]
name = "pyright"
version = "1.1.333"
description = "Command line wrapper for pyright"
optional = false
python-versions = ">=3.7"
files = [
{file = "pyright-1.1.333-py3-none-any.whl", hash = "sha256:f0a7b7b0cac11c396b17ef3cf6c8527aca1269edaf5cf8203eed7d6dd1ef52aa"},
{file = "pyright-1.1.333.tar.gz", hash = "sha256:1c49b0029048120c4378f3baf6c1dcbbfb221678bb69654fe773c514430ac53c"},
]
[package.dependencies]
nodeenv = ">=1.6.0"
[package.extras]
all = ["twine (>=3.4.1)"]
dev = ["twine (>=3.4.1)"]
[[package]]
name = "pytest"
version = "7.4.4"
@@ -1639,6 +1641,7 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -2350,4 +2353,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<4.0"
content-hash = "882fc14146aca5bf99c1c0512d74622f35c91a5b3d01d56828b564e4c554eaba"
content-hash = "21b3fbe3c3dde7aab6f5a00eae25c97d296fda3ee095914adf2b4b789d36e19d"

View File

@@ -18,13 +18,15 @@ Repository = "https://github.com/joaomdmoura/crewai"
[tool.poetry.dependencies]
python = ">=3.9,<4.0"
pydantic = "^2.4.2"
langchain = "^0.1.0"
langchain = "0.1.0"
openai = "^1.7.1"
langchain-openai = "^0.0.2"
pyright = "1.1.333"
black = {git = "https://github.com/psf/black.git", rev = "stable"}
[tool.poetry.group.dev.dependencies]
isort = "^5.13.2"
black = "^23.12.1"
black = "^24.1"
autoflake = "^2.2.1"
pre-commit = "^3.6.0"
mkdocs-material = "^9.5.3"
@@ -33,6 +35,7 @@ mkdocs-material = "^9.5.3"
profile = "black"
known_first_party = ["crewai"]
[tool.poetry.group.test.dependencies]
pytest = "^7.4"
pytest-vcr = "^1.0.2"

View File

@@ -2,10 +2,12 @@ import uuid
from typing import Any, List, Optional
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.agent import RunnableAgent
from langchain.memory import ConversationSummaryMemory
from langchain.tools.render import render_text_description
from langchain_core.runnables.config import RunnableConfig
from langchain_openai import ChatOpenAI
from langchain_core.language_models import BaseLanguageModel
from pydantic import (
UUID4,
BaseModel,
@@ -47,7 +49,7 @@ class Agent(BaseModel):
tools: Tools at agents disposal
"""
__hash__ = object.__hash__
__hash__ = object.__hash__ # type: ignore
_logger: Logger = PrivateAttr()
_rpm_controller: RPMController = PrivateAttr(default=None)
_request_within_rpm_limit: Any = PrivateAttr(default=None)
@@ -80,21 +82,19 @@ class Agent(BaseModel):
max_iter: Optional[int] = Field(
default=15, description="Maximum iterations for an agent to execute a task"
)
agent_executor: Optional[InstanceOf[CrewAgentExecutor]] = Field(
agent_executor: InstanceOf[CrewAgentExecutor] = Field(
default=None, description="An instance of the CrewAgentExecutor class."
)
tools_handler: Optional[InstanceOf[ToolsHandler]] = Field(
tools_handler: InstanceOf[ToolsHandler] = Field(
default=None, description="An instance of the ToolsHandler class."
)
cache_handler: Optional[InstanceOf[CacheHandler]] = Field(
cache_handler: InstanceOf[CacheHandler] = Field(
default=CacheHandler(), description="An instance of the CacheHandler class."
)
i18n: Optional[I18N] = Field(
default=I18N(), description="Internationalization settings."
)
llm: Optional[Any] = Field(
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
llm: Any = Field(
default_factory=lambda: ChatOpenAI(
model_name="gpt-4",
model="gpt-4",
),
description="Language model that will run the agent.",
)
@@ -140,6 +140,7 @@ class Agent(BaseModel):
Returns:
Output of the agent
"""
if context:
task = self.i18n.slice("task_with_context").format(
task=task, context=context
@@ -203,9 +204,9 @@ class Agent(BaseModel):
}
if self._rpm_controller:
executor_args[
"request_within_rpm_limit"
] = self._rpm_controller.check_or_wait
executor_args["request_within_rpm_limit"] = (
self._rpm_controller.check_or_wait
)
if self.memory:
summary_memory = ConversationSummaryMemory(
@@ -234,7 +235,9 @@ class Agent(BaseModel):
i18n=self.i18n,
)
)
self.agent_executor = CrewAgentExecutor(agent=inner_agent, **executor_args)
self.agent_executor = CrewAgentExecutor(
agent=RunnableAgent(runnable=inner_agent), **executor_args
)
@staticmethod
def __tools_names(tools) -> str:

View File

@@ -1,12 +1,10 @@
from typing import Optional
from pydantic import PrivateAttr
class CacheHandler:
"""Callback handler for tool usage."""
_cache: PrivateAttr = {}
_cache: dict = {}
def __init__(self):
self._cache = {}

View File

@@ -108,8 +108,12 @@ class CrewAgentExecutor(AgentExecutor):
if self._should_force_answer():
if isinstance(output, AgentAction):
output = output
else:
elif isinstance(output, CacheHit):
output = output.action
else:
raise ValueError(
f"Unexpected output type from agent: {type(output)}"
)
yield self._force_answer(output)
return

View File

@@ -50,7 +50,6 @@ class CrewAgentOutputParser(ReActSingleInputOutputParser):
i18n: I18N
def parse(self, text: str) -> Union[AgentAction, AgentFinish, CacheHit]:
FINAL_ANSWER_ACTION in text
regex = (
r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
)

View File

@@ -10,9 +10,9 @@ class ToolsHandler(BaseCallbackHandler):
"""Callback handler for tool usage."""
last_used_tool: Dict[str, Any] = {}
cache: CacheHandler = None
cache: CacheHandler
def __init__(self, cache: CacheHandler = None, **kwargs: Any):
def __init__(self, cache: CacheHandler, **kwargs: Any):
"""Initialize the callback handler."""
self.cache = cache
super().__init__(**kwargs)

View File

@@ -38,12 +38,10 @@ class Crew(BaseModel):
id: A unique identifier for the crew instance.
"""
__hash__ = object.__hash__
__hash__ = object.__hash__ # type: ignore
_rpm_controller: RPMController = PrivateAttr()
_logger: Logger = PrivateAttr()
_cache_handler: Optional[InstanceOf[CacheHandler]] = PrivateAttr(
default=CacheHandler()
)
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
model_config = ConfigDict(arbitrary_types_allowed=True)
tasks: List[Task] = Field(default_factory=list)
agents: List[Agent] = Field(default_factory=list)
@@ -69,20 +67,20 @@ class Crew(BaseModel):
"may_not_set_field", "The 'id' field cannot be set by the user.", {}
)
@classmethod
@field_validator("config", mode="before")
@classmethod
def check_config_type(
cls, v: Union[Json, Dict[str, Any]]
) -> Union[Json, Dict[str, Any]]:
"""Validates that the config is a valid type.
Args:
v: The config to be validated.
Returns:
The config if it is valid.
"""
return json.loads(v) if isinstance(v, Json) else v
# TODO: Improve typing
return json.loads(v) if isinstance(v, Json) else v # type: ignore
@model_validator(mode="after")
def set_private_attrs(self) -> "Crew":
@@ -112,6 +110,8 @@ class Crew(BaseModel):
return self
def _setup_from_config(self):
assert self.config is not None, "Config should not be None."
"""Initializes agents and tasks from the provided config."""
if not self.config.get("agents") or not self.config.get("tasks"):
raise PydanticCustomError(
@@ -143,19 +143,24 @@ class Crew(BaseModel):
if self.process == Process.sequential:
return self._sequential_loop()
else:
raise NotImplementedError(
f"The process '{self.process}' is not implemented yet."
)
def _sequential_loop(self) -> str:
"""Executes tasks sequentially and returns the final output."""
task_output = None
task_output = ""
for task in self.tasks:
self._prepare_and_execute_task(task)
task_output = task.execute(task_output)
self._logger.log(
"debug", f"[{task.agent.role}] Task output: {task_output}\n\n"
)
role = task.agent.role if task.agent is not None else "None"
self._logger.log("debug", f"[{role}] Task output: {task_output}\n\n")
if self.max_rpm:
self._rpm_controller.stop_rpm_counter()
return task_output
def _prepare_and_execute_task(self, task: Task) -> None:
@@ -164,8 +169,9 @@ class Crew(BaseModel):
Args:
task: The task to be executed.
"""
if task.agent.allow_delegation:
if task.agent is not None and task.agent.allow_delegation:
task.tools += AgentTools(agents=self.agents).tools()
self._logger.log("debug", f"Working Agent: {task.agent.role}")
role = task.agent.role if task.agent is not None else "None"
self._logger.log("debug", f"Working Agent: {role}")
self._logger.log("info", f"Starting Task: {task.description}")

View File

@@ -12,7 +12,7 @@ from crewai.utilities import I18N
class Task(BaseModel):
"""Class that represent a task to be executed."""
__hash__ = object.__hash__
__hash__ = object.__hash__ # type: ignore
i18n: I18N = I18N()
description: str = Field(description="Description of the actual task.")
agent: Optional[Agent] = Field(
@@ -20,7 +20,7 @@ class Task(BaseModel):
)
tools: List[Any] = Field(
default_factory=list,
description="Tools the agent are limited to use for this task.",
description="Tools the agent is limited to use for this task.",
)
expected_output: str = Field(
description="Clear definition of expected output for the task.",
@@ -46,7 +46,7 @@ class Task(BaseModel):
@model_validator(mode="after")
def check_tools(self):
"""Check if the tools are set."""
if not self.tools and (self.agent and self.agent.tools):
if not self.tools and self.agent and self.agent.tools:
self.tools.extend(self.agent.tools)
return self

View File

@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List
from langchain.tools import Tool
from pydantic import BaseModel, Field
@@ -11,9 +11,7 @@ class AgentTools(BaseModel):
"""Default tools around agent delegation"""
agents: List[Agent] = Field(description="List of agents in this crew.")
i18n: Optional[I18N] = Field(
default=I18N(), description="Internationalization settings."
)
i18n: I18N = Field(default=I18N(), description="Internationalization settings.")
def tools(self):
return [

View File

@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, PrivateAttr, ValidationError, model_valid
class I18N(BaseModel):
_translations: Optional[Dict[str, str]] = PrivateAttr()
_translations: Dict[str, Dict[str, str]] = PrivateAttr()
language: Optional[str] = Field(
default="en",
description="Language used to load translations",
@@ -25,10 +25,14 @@ class I18N(BaseModel):
self._translations = json.load(f)
except FileNotFoundError:
raise ValidationError(
f"Trasnlation file for language '{self.language}' not found."
f"Translation file for language '{self.language}' not found."
)
except json.JSONDecodeError:
raise ValidationError(f"Error decoding JSON from the prompts file.")
if not self._translations:
self._translations = {}
return self
def slice(self, slice: str) -> str:
@@ -40,8 +44,8 @@ class I18N(BaseModel):
def tools(self, error: str) -> str:
return self.retrieve("tools", error)
def retrieve(self, kind, key):
def retrieve(self, kind, key) -> str:
try:
return self._translations[kind].get(key)
return self._translations[kind][key]
except:
raise ValidationError(f"Translation for '{kind}':'{key}' not found.")

View File

@@ -1,6 +1,6 @@
from typing import ClassVar
from langchain.prompts import PromptTemplate
from langchain.prompts import PromptTemplate, BasePromptTemplate
from pydantic import BaseModel, Field
from crewai.utilities import I18N
@@ -13,19 +13,19 @@ class Prompts(BaseModel):
SCRATCHPAD_SLICE: ClassVar[str] = "\n{agent_scratchpad}"
def task_execution_with_memory(self) -> str:
def task_execution_with_memory(self) -> BasePromptTemplate:
"""Generate a prompt for task execution with memory components."""
return self._build_prompt(["role_playing", "tools", "memory", "task"])
def task_execution_without_tools(self) -> str:
def task_execution_without_tools(self) -> BasePromptTemplate:
"""Generate a prompt for task execution without tools components."""
return self._build_prompt(["role_playing", "task"])
def task_execution(self) -> str:
def task_execution(self) -> BasePromptTemplate:
"""Generate a standard prompt for task execution."""
return self._build_prompt(["role_playing", "tools", "task"])
def _build_prompt(self, components: [str]) -> str:
def _build_prompt(self, components: list[str]) -> BasePromptTemplate:
"""Constructs a prompt string from specified components."""
prompt_parts = [self.i18n.slice(component) for component in components]
prompt_parts.append(self.SCRATCHPAD_SLICE)

View File

@@ -12,7 +12,7 @@ class RPMController(BaseModel):
max_rpm: Union[int, None] = Field(default=None)
logger: Logger = Field(default=None)
_current_rpm: int = PrivateAttr(default=0)
_timer: threading.Timer = PrivateAttr(default=None)
_timer: threading.Timer | None = PrivateAttr(default=None)
_lock: threading.Lock = PrivateAttr(default=None)
@model_validator(mode="after")

View File

@@ -1,6 +1,5 @@
"""Test Agent creation and execution basic functionality."""
from crewai.agent import Agent
from crewai.task import Task
@@ -51,7 +50,6 @@ def test_task_tool_takes_precedence_ove_agent_tools():
description="Give me a list of 5 interesting ideas to explore for na article, what makes them unique and interesting.",
agent=researcher,
tools=[fake_task_tool],
allow_delegation=False,
)
assert task.tools == [fake_task_tool]
@@ -69,7 +67,6 @@ def test_task_prompt_includes_expected_output():
description="Give me a list of 5 interesting ideas to explore for na article, what makes them unique and interesting.",
expected_output="Bullet point list of 5 interesting ideas.",
agent=researcher,
allow_delegation=False,
)
from unittest.mock import patch