mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-17 12:58:31 +00:00
Compare commits
3 Commits
bugfix/mem
...
feat/add-i
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9be9d7d61c | ||
|
|
d2db938d50 | ||
|
|
cf90bd3105 |
@@ -330,4 +330,4 @@ This will clear the crew's memory, allowing for a fresh start.
|
||||
|
||||
## Deploying Your Project
|
||||
|
||||
The easiest way to deploy your crew is through [CrewAI Enterprise](http://app.crewai.com/), where you can deploy your crew in a few clicks.
|
||||
The easiest way to deploy your crew is through [CrewAI Enterprise](https://www.crewai.com/crewaiplus), where you can deploy your crew in a few clicks.
|
||||
|
||||
@@ -8,7 +8,6 @@ from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
from crewai.agents import CacheHandler
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.cli.constants import ENV_VARS
|
||||
from crewai.llm import LLM
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
@@ -132,12 +131,8 @@ class Agent(BaseAgent):
|
||||
# If it's already an LLM instance, keep it as is
|
||||
pass
|
||||
elif self.llm is None:
|
||||
# Determine the model name from environment variables or use default
|
||||
model_name = (
|
||||
os.environ.get("OPENAI_MODEL_NAME")
|
||||
or os.environ.get("MODEL")
|
||||
or "gpt-4o-mini"
|
||||
)
|
||||
# If it's None, use environment variables or default
|
||||
model_name = os.environ.get("OPENAI_MODEL_NAME", "gpt-4o-mini")
|
||||
llm_params = {"model": model_name}
|
||||
|
||||
api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get(
|
||||
@@ -146,39 +141,9 @@ class Agent(BaseAgent):
|
||||
if api_base:
|
||||
llm_params["base_url"] = api_base
|
||||
|
||||
# Iterate over all environment variables to find matching API keys or use defaults
|
||||
for provider, env_vars in ENV_VARS.items():
|
||||
for env_var in env_vars:
|
||||
# Check if the environment variable is set
|
||||
if "key_name" in env_var:
|
||||
env_value = os.environ.get(env_var["key_name"])
|
||||
if env_value:
|
||||
# Map key names containing "API_KEY" to "api_key"
|
||||
key_name = (
|
||||
"api_key"
|
||||
if "API_KEY" in env_var["key_name"]
|
||||
else env_var["key_name"]
|
||||
)
|
||||
# Map key names containing "API_BASE" to "api_base"
|
||||
key_name = (
|
||||
"api_base"
|
||||
if "API_BASE" in env_var["key_name"]
|
||||
else key_name
|
||||
)
|
||||
# Map key names containing "API_VERSION" to "api_version"
|
||||
key_name = (
|
||||
"api_version"
|
||||
if "API_VERSION" in env_var["key_name"]
|
||||
else key_name
|
||||
)
|
||||
llm_params[key_name] = env_value
|
||||
# Check for default values if the environment variable is not set
|
||||
elif env_var.get("default", False):
|
||||
for key, value in env_var.items():
|
||||
if key not in ["prompt", "key_name", "default"]:
|
||||
# Only add default if the key is already set in os.environ
|
||||
if key in os.environ:
|
||||
llm_params[key] = value
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
llm_params["api_key"] = api_key
|
||||
|
||||
self.llm = LLM(**llm_params)
|
||||
else:
|
||||
|
||||
@@ -117,15 +117,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
callbacks=self.callbacks,
|
||||
)
|
||||
|
||||
if answer is None or answer == "":
|
||||
self._printer.print(
|
||||
content="Received None or empty response from LLM call.",
|
||||
color="red",
|
||||
)
|
||||
raise ValueError(
|
||||
"Invalid response from LLM call - None or empty."
|
||||
)
|
||||
|
||||
if not self.use_stop_words:
|
||||
try:
|
||||
self._format_answer(answer)
|
||||
|
||||
@@ -54,7 +54,7 @@ def create_embedded_crew(crew_name: str, parent_folder: Path) -> None:
|
||||
|
||||
templates_dir = Path(__file__).parent / "templates" / "crew"
|
||||
config_template_files = ["agents.yaml", "tasks.yaml"]
|
||||
crew_template_file = f"{folder_name}.py" # Updated file name
|
||||
crew_template_file = f"{folder_name}_crew.py" # Updated file name
|
||||
|
||||
for file_name in config_template_files:
|
||||
src_file = templates_dir / "config" / file_name
|
||||
|
||||
@@ -1,168 +1,19 @@
|
||||
ENV_VARS = {
|
||||
"openai": [
|
||||
{
|
||||
"prompt": "Enter your OPENAI API key (press Enter to skip)",
|
||||
"key_name": "OPENAI_API_KEY",
|
||||
}
|
||||
],
|
||||
"anthropic": [
|
||||
{
|
||||
"prompt": "Enter your ANTHROPIC API key (press Enter to skip)",
|
||||
"key_name": "ANTHROPIC_API_KEY",
|
||||
}
|
||||
],
|
||||
"gemini": [
|
||||
{
|
||||
"prompt": "Enter your GEMINI API key (press Enter to skip)",
|
||||
"key_name": "GEMINI_API_KEY",
|
||||
}
|
||||
],
|
||||
"groq": [
|
||||
{
|
||||
"prompt": "Enter your GROQ API key (press Enter to skip)",
|
||||
"key_name": "GROQ_API_KEY",
|
||||
}
|
||||
],
|
||||
"watson": [
|
||||
{
|
||||
"prompt": "Enter your WATSONX URL (press Enter to skip)",
|
||||
"key_name": "WATSONX_URL",
|
||||
},
|
||||
{
|
||||
"prompt": "Enter your WATSONX API Key (press Enter to skip)",
|
||||
"key_name": "WATSONX_APIKEY",
|
||||
},
|
||||
{
|
||||
"prompt": "Enter your WATSONX Project Id (press Enter to skip)",
|
||||
"key_name": "WATSONX_PROJECT_ID",
|
||||
},
|
||||
],
|
||||
"ollama": [
|
||||
{
|
||||
"default": True,
|
||||
"API_BASE": "http://localhost:11434",
|
||||
}
|
||||
],
|
||||
"bedrock": [
|
||||
{
|
||||
"prompt": "Enter your AWS Access Key ID (press Enter to skip)",
|
||||
"key_name": "AWS_ACCESS_KEY_ID",
|
||||
},
|
||||
{
|
||||
"prompt": "Enter your AWS Secret Access Key (press Enter to skip)",
|
||||
"key_name": "AWS_SECRET_ACCESS_KEY",
|
||||
},
|
||||
{
|
||||
"prompt": "Enter your AWS Region Name (press Enter to skip)",
|
||||
"key_name": "AWS_REGION_NAME",
|
||||
},
|
||||
],
|
||||
"azure": [
|
||||
{
|
||||
"prompt": "Enter your Azure deployment name (must start with 'azure/')",
|
||||
"key_name": "model",
|
||||
},
|
||||
{
|
||||
"prompt": "Enter your AZURE API key (press Enter to skip)",
|
||||
"key_name": "AZURE_API_KEY",
|
||||
},
|
||||
{
|
||||
"prompt": "Enter your AZURE API base URL (press Enter to skip)",
|
||||
"key_name": "AZURE_API_BASE",
|
||||
},
|
||||
{
|
||||
"prompt": "Enter your AZURE API version (press Enter to skip)",
|
||||
"key_name": "AZURE_API_VERSION",
|
||||
},
|
||||
],
|
||||
"cerebras": [
|
||||
{
|
||||
"prompt": "Enter your Cerebras model name (must start with 'cerebras/')",
|
||||
"key_name": "model",
|
||||
},
|
||||
{
|
||||
"prompt": "Enter your Cerebras API version (press Enter to skip)",
|
||||
"key_name": "CEREBRAS_API_KEY",
|
||||
},
|
||||
],
|
||||
'openai': ['OPENAI_API_KEY'],
|
||||
'anthropic': ['ANTHROPIC_API_KEY'],
|
||||
'gemini': ['GEMINI_API_KEY'],
|
||||
'groq': ['GROQ_API_KEY'],
|
||||
'ollama': ['FAKE_KEY'],
|
||||
}
|
||||
|
||||
|
||||
PROVIDERS = [
|
||||
"openai",
|
||||
"anthropic",
|
||||
"gemini",
|
||||
"groq",
|
||||
"ollama",
|
||||
"watson",
|
||||
"bedrock",
|
||||
"azure",
|
||||
"cerebras",
|
||||
]
|
||||
PROVIDERS = ['openai', 'anthropic', 'gemini', 'groq', 'ollama']
|
||||
|
||||
MODELS = {
|
||||
"openai": ["gpt-4", "gpt-4o", "gpt-4o-mini", "o1-mini", "o1-preview"],
|
||||
"anthropic": [
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
],
|
||||
"gemini": [
|
||||
"gemini/gemini-1.5-flash",
|
||||
"gemini/gemini-1.5-pro",
|
||||
"gemini/gemini-gemma-2-9b-it",
|
||||
"gemini/gemini-gemma-2-27b-it",
|
||||
],
|
||||
"groq": [
|
||||
"groq/llama-3.1-8b-instant",
|
||||
"groq/llama-3.1-70b-versatile",
|
||||
"groq/llama-3.1-405b-reasoning",
|
||||
"groq/gemma2-9b-it",
|
||||
"groq/gemma-7b-it",
|
||||
],
|
||||
"ollama": ["ollama/llama3.1", "ollama/mixtral"],
|
||||
"watson": [
|
||||
"watsonx/google/flan-t5-xxl",
|
||||
"watsonx/google/flan-ul2",
|
||||
"watsonx/bigscience/mt0-xxl",
|
||||
"watsonx/eleutherai/gpt-neox-20b",
|
||||
"watsonx/ibm/mpt-7b-instruct2",
|
||||
"watsonx/bigcode/starcoder",
|
||||
"watsonx/meta-llama/llama-2-70b-chat",
|
||||
"watsonx/meta-llama/llama-2-13b-chat",
|
||||
"watsonx/ibm/granite-13b-instruct-v1",
|
||||
"watsonx/ibm/granite-13b-chat-v1",
|
||||
"watsonx/google/flan-t5-xl",
|
||||
"watsonx/ibm/granite-13b-chat-v2",
|
||||
"watsonx/ibm/granite-13b-instruct-v2",
|
||||
"watsonx/elyza/elyza-japanese-llama-2-7b-instruct",
|
||||
"watsonx/ibm-mistralai/mixtral-8x7b-instruct-v01-q",
|
||||
],
|
||||
"bedrock": [
|
||||
"bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"bedrock/anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"bedrock/anthropic.claude-3-opus-20240229-v1:0",
|
||||
"bedrock/anthropic.claude-v2:1",
|
||||
"bedrock/anthropic.claude-v2",
|
||||
"bedrock/anthropic.claude-instant-v1",
|
||||
"bedrock/meta.llama3-1-405b-instruct-v1:0",
|
||||
"bedrock/meta.llama3-1-70b-instruct-v1:0",
|
||||
"bedrock/meta.llama3-1-8b-instruct-v1:0",
|
||||
"bedrock/meta.llama3-70b-instruct-v1:0",
|
||||
"bedrock/meta.llama3-8b-instruct-v1:0",
|
||||
"bedrock/amazon.titan-text-lite-v1",
|
||||
"bedrock/amazon.titan-text-express-v1",
|
||||
"bedrock/cohere.command-text-v14",
|
||||
"bedrock/ai21.j2-mid-v1",
|
||||
"bedrock/ai21.j2-ultra-v1",
|
||||
"bedrock/ai21.jamba-instruct-v1:0",
|
||||
"bedrock/meta.llama2-13b-chat-v1",
|
||||
"bedrock/meta.llama2-70b-chat-v1",
|
||||
"bedrock/mistral.mistral-7b-instruct-v0:2",
|
||||
"bedrock/mistral.mixtral-8x7b-instruct-v0:1",
|
||||
],
|
||||
'openai': ['gpt-4', 'gpt-4o', 'gpt-4o-mini', 'o1-mini', 'o1-preview'],
|
||||
'anthropic': ['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'],
|
||||
'gemini': ['gemini-1.5-flash', 'gemini-1.5-pro', 'gemini-gemma-2-9b-it', 'gemini-gemma-2-27b-it'],
|
||||
'groq': ['llama-3.1-8b-instant', 'llama-3.1-70b-versatile', 'llama-3.1-405b-reasoning', 'gemma2-9b-it', 'gemma-7b-it'],
|
||||
'ollama': ['llama3.1', 'mixtral'],
|
||||
}
|
||||
|
||||
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||
@@ -1,11 +1,11 @@
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
|
||||
from crewai.cli.constants import ENV_VARS, MODELS
|
||||
from crewai.cli.constants import ENV_VARS
|
||||
from crewai.cli.provider import (
|
||||
PROVIDERS,
|
||||
get_provider_data,
|
||||
select_model,
|
||||
select_provider,
|
||||
@@ -29,20 +29,20 @@ def create_folder_structure(name, parent_folder=None):
|
||||
click.secho("Operation cancelled.", fg="yellow")
|
||||
sys.exit(0)
|
||||
click.secho(f"Overriding folder {folder_name}...", fg="green", bold=True)
|
||||
shutil.rmtree(folder_path) # Delete the existing folder and its contents
|
||||
else:
|
||||
click.secho(
|
||||
f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...",
|
||||
fg="green",
|
||||
bold=True,
|
||||
)
|
||||
|
||||
click.secho(
|
||||
f"Creating {'crew' if parent_folder else 'folder'} {folder_name}...",
|
||||
fg="green",
|
||||
bold=True,
|
||||
)
|
||||
|
||||
folder_path.mkdir(parents=True)
|
||||
(folder_path / "tests").mkdir(exist_ok=True)
|
||||
if not parent_folder:
|
||||
(folder_path / "src" / folder_name).mkdir(parents=True)
|
||||
(folder_path / "src" / folder_name / "tools").mkdir(parents=True)
|
||||
(folder_path / "src" / folder_name / "config").mkdir(parents=True)
|
||||
if not folder_path.exists():
|
||||
folder_path.mkdir(parents=True)
|
||||
(folder_path / "tests").mkdir(exist_ok=True)
|
||||
if not parent_folder:
|
||||
(folder_path / "src" / folder_name).mkdir(parents=True)
|
||||
(folder_path / "src" / folder_name / "tools").mkdir(parents=True)
|
||||
(folder_path / "src" / folder_name / "config").mkdir(parents=True)
|
||||
|
||||
return folder_path, folder_name, class_name
|
||||
|
||||
@@ -92,10 +92,7 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
|
||||
|
||||
existing_provider = None
|
||||
for provider, env_keys in ENV_VARS.items():
|
||||
if any(
|
||||
"key_name" in details and details["key_name"] in env_vars
|
||||
for details in env_keys
|
||||
):
|
||||
if any(key in env_vars for key in env_keys):
|
||||
existing_provider = provider
|
||||
break
|
||||
|
||||
@@ -121,48 +118,47 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
|
||||
"No provider selected. Please try again or press 'q' to exit.", fg="red"
|
||||
)
|
||||
|
||||
# Check if the selected provider has predefined models
|
||||
if selected_provider in MODELS and MODELS[selected_provider]:
|
||||
while True:
|
||||
selected_model = select_model(selected_provider, provider_models)
|
||||
if selected_model is None: # User typed 'q'
|
||||
click.secho("Exiting...", fg="yellow")
|
||||
sys.exit(0)
|
||||
if selected_model: # Valid selection
|
||||
break
|
||||
click.secho(
|
||||
"No model selected. Please try again or press 'q' to exit.",
|
||||
fg="red",
|
||||
)
|
||||
env_vars["MODEL"] = selected_model
|
||||
|
||||
# Check if the selected provider requires API keys
|
||||
if selected_provider in ENV_VARS:
|
||||
provider_env_vars = ENV_VARS[selected_provider]
|
||||
for details in provider_env_vars:
|
||||
if details.get("default", False):
|
||||
# Automatically add default key-value pairs
|
||||
for key, value in details.items():
|
||||
if key not in ["prompt", "key_name", "default"]:
|
||||
env_vars[key] = value
|
||||
elif "key_name" in details:
|
||||
# Prompt for non-default key-value pairs
|
||||
prompt = details["prompt"]
|
||||
key_name = details["key_name"]
|
||||
api_key_value = click.prompt(prompt, default="", show_default=False)
|
||||
|
||||
if api_key_value.strip():
|
||||
env_vars[key_name] = api_key_value
|
||||
|
||||
if env_vars:
|
||||
write_env_file(folder_path, env_vars)
|
||||
click.secho("API keys and model saved to .env file", fg="green")
|
||||
else:
|
||||
while True:
|
||||
selected_model = select_model(selected_provider, provider_models)
|
||||
if selected_model is None: # User typed 'q'
|
||||
click.secho("Exiting...", fg="yellow")
|
||||
sys.exit(0)
|
||||
if selected_model: # Valid selection
|
||||
break
|
||||
click.secho(
|
||||
"No API keys provided. Skipping .env file creation.", fg="yellow"
|
||||
"No model selected. Please try again or press 'q' to exit.", fg="red"
|
||||
)
|
||||
|
||||
click.secho(f"Selected model: {env_vars.get('MODEL', 'N/A')}", fg="green")
|
||||
if selected_provider in PROVIDERS:
|
||||
api_key_var = ENV_VARS[selected_provider][0]
|
||||
else:
|
||||
api_key_var = click.prompt(
|
||||
f"Enter the environment variable name for your {selected_provider.capitalize()} API key",
|
||||
type=str,
|
||||
default="",
|
||||
)
|
||||
|
||||
api_key_value = ""
|
||||
click.echo(
|
||||
f"Enter your {selected_provider.capitalize()} API key (press Enter to skip): ",
|
||||
nl=False,
|
||||
)
|
||||
try:
|
||||
api_key_value = input()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
api_key_value = ""
|
||||
|
||||
if api_key_value.strip():
|
||||
env_vars = {api_key_var: api_key_value}
|
||||
write_env_file(folder_path, env_vars)
|
||||
click.secho("API key saved to .env file", fg="green")
|
||||
else:
|
||||
click.secho(
|
||||
"No API key provided. Skipping .env file creation.", fg="yellow"
|
||||
)
|
||||
|
||||
env_vars["MODEL"] = selected_model
|
||||
click.secho(f"Selected model: {selected_model}", fg="green")
|
||||
|
||||
package_dir = Path(__file__).parent
|
||||
templates_dir = package_dir / "templates" / "crew"
|
||||
|
||||
@@ -164,7 +164,7 @@ def fetch_provider_data(cache_file):
|
||||
- dict or None: The fetched provider data or None if the operation fails.
|
||||
"""
|
||||
try:
|
||||
response = requests.get(JSON_URL, stream=True, timeout=60)
|
||||
response = requests.get(JSON_URL, stream=True, timeout=10)
|
||||
response.raise_for_status()
|
||||
data = download_data(response)
|
||||
with open(cache_file, "w") as f:
|
||||
|
||||
@@ -8,12 +8,9 @@ from crewai.project import CrewBase, agent, crew, task
|
||||
# from crewai_tools import SerperDevTool
|
||||
|
||||
@CrewBase
|
||||
class {{crew_name}}():
|
||||
class {{crew_name}}Crew():
|
||||
"""{{crew_name}} crew"""
|
||||
|
||||
agents_config = 'config/agents.yaml'
|
||||
tasks_config = 'config/tasks.yaml'
|
||||
|
||||
@agent
|
||||
def researcher(self) -> Agent:
|
||||
return Agent(
|
||||
@@ -51,4 +48,4 @@ class {{crew_name}}():
|
||||
process=Process.sequential,
|
||||
verbose=True,
|
||||
# process=Process.hierarchical, # In case you wanna use that instead https://docs.crewai.com/how-to/Hierarchical/
|
||||
)
|
||||
)
|
||||
@@ -1,10 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
from {{folder_name}}.crew import {{crew_name}}
|
||||
|
||||
warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
|
||||
from {{folder_name}}.crew import {{crew_name}}Crew
|
||||
|
||||
# This main file is intended to be a way for you to run your
|
||||
# crew locally, so refrain from adding unnecessary logic into this file.
|
||||
@@ -18,7 +14,7 @@ def run():
|
||||
inputs = {
|
||||
'topic': 'AI LLMs'
|
||||
}
|
||||
{{crew_name}}().crew().kickoff(inputs=inputs)
|
||||
{{crew_name}}Crew().crew().kickoff(inputs=inputs)
|
||||
|
||||
|
||||
def train():
|
||||
@@ -29,7 +25,7 @@ def train():
|
||||
"topic": "AI LLMs"
|
||||
}
|
||||
try:
|
||||
{{crew_name}}().crew().train(n_iterations=int(sys.argv[1]), filename=sys.argv[2], inputs=inputs)
|
||||
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), filename=sys.argv[2], inputs=inputs)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while training the crew: {e}")
|
||||
@@ -39,7 +35,7 @@ def replay():
|
||||
Replay the crew execution from a specific task.
|
||||
"""
|
||||
try:
|
||||
{{crew_name}}().crew().replay(task_id=sys.argv[1])
|
||||
{{crew_name}}Crew().crew().replay(task_id=sys.argv[1])
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while replaying the crew: {e}")
|
||||
@@ -52,7 +48,7 @@ def test():
|
||||
"topic": "AI LLMs"
|
||||
}
|
||||
try:
|
||||
{{crew_name}}().crew().test(n_iterations=int(sys.argv[1]), openai_model_name=sys.argv[2], inputs=inputs)
|
||||
{{crew_name}}Crew().crew().test(n_iterations=int(sys.argv[1]), openai_model_name=sys.argv[2], inputs=inputs)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while replaying the crew: {e}")
|
||||
|
||||
@@ -445,14 +445,13 @@ class Crew(BaseModel):
|
||||
training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
|
||||
|
||||
for agent in train_crew.agents:
|
||||
if training_data.get(str(agent.id)):
|
||||
result = TaskEvaluator(agent).evaluate_training_data(
|
||||
training_data=training_data, agent_id=str(agent.id)
|
||||
)
|
||||
result = TaskEvaluator(agent).evaluate_training_data(
|
||||
training_data=training_data, agent_id=str(agent.id)
|
||||
)
|
||||
|
||||
CrewTrainingHandler(filename).save_trained_data(
|
||||
agent_id=str(agent.role), trained_data=result.model_dump()
|
||||
)
|
||||
CrewTrainingHandler(filename).save_trained_data(
|
||||
agent_id=str(agent.role), trained_data=result.model_dump()
|
||||
)
|
||||
|
||||
def kickoff(
|
||||
self,
|
||||
|
||||
@@ -131,6 +131,7 @@ class FlowMeta(type):
|
||||
condition_type = getattr(attr_value, "__condition_type__", "OR")
|
||||
listeners[attr_name] = (condition_type, methods)
|
||||
|
||||
# TODO: should we add a check for __condition_type__ 'AND'?
|
||||
elif hasattr(attr_value, "__is_router__"):
|
||||
routers[attr_value.__router_for__] = attr_name
|
||||
possible_returns = get_possible_return_constants(attr_value)
|
||||
@@ -170,7 +171,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
def __init__(self) -> None:
|
||||
self._methods: Dict[str, Callable] = {}
|
||||
self._state: T = self._create_initial_state()
|
||||
self._method_execution_counts: Dict[str, int] = {}
|
||||
self._executed_methods: Set[str] = set()
|
||||
self._scheduled_tasks: Set[str] = set()
|
||||
self._pending_and_listeners: Dict[str, Set[str]] = {}
|
||||
self._method_outputs: List[Any] = [] # List to store all method outputs
|
||||
|
||||
@@ -307,10 +309,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
self._method_outputs.append(result) # Store the output
|
||||
|
||||
# Track method execution counts
|
||||
self._method_execution_counts[method_name] = (
|
||||
self._method_execution_counts.get(method_name, 0) + 1
|
||||
)
|
||||
self._executed_methods.add(method_name)
|
||||
|
||||
return result
|
||||
|
||||
@@ -320,34 +319,35 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if trigger_method in self._routers:
|
||||
router_method = self._methods[self._routers[trigger_method]]
|
||||
path = await self._execute_method(
|
||||
self._routers[trigger_method], router_method
|
||||
)
|
||||
trigger_method, router_method
|
||||
) # TODO: Change or not?
|
||||
# Use the path as the new trigger method
|
||||
trigger_method = path
|
||||
|
||||
for listener_name, (condition_type, methods) in self._listeners.items():
|
||||
if condition_type == "OR":
|
||||
if trigger_method in methods:
|
||||
# Schedule the listener without preventing re-execution
|
||||
listener_tasks.append(
|
||||
self._execute_single_listener(listener_name, result)
|
||||
)
|
||||
if (
|
||||
listener_name not in self._executed_methods
|
||||
and listener_name not in self._scheduled_tasks
|
||||
):
|
||||
self._scheduled_tasks.add(listener_name)
|
||||
listener_tasks.append(
|
||||
self._execute_single_listener(listener_name, result)
|
||||
)
|
||||
elif condition_type == "AND":
|
||||
# Initialize pending methods for this listener if not already done
|
||||
if listener_name not in self._pending_and_listeners:
|
||||
self._pending_and_listeners[listener_name] = set(methods)
|
||||
# Remove the trigger method from pending methods
|
||||
self._pending_and_listeners[listener_name].discard(trigger_method)
|
||||
if not self._pending_and_listeners[listener_name]:
|
||||
# All required methods have been executed
|
||||
listener_tasks.append(
|
||||
self._execute_single_listener(listener_name, result)
|
||||
)
|
||||
# Reset pending methods for this listener
|
||||
self._pending_and_listeners.pop(listener_name, None)
|
||||
if all(method in self._executed_methods for method in methods):
|
||||
if (
|
||||
listener_name not in self._executed_methods
|
||||
and listener_name not in self._scheduled_tasks
|
||||
):
|
||||
self._scheduled_tasks.add(listener_name)
|
||||
listener_tasks.append(
|
||||
self._execute_single_listener(listener_name, result)
|
||||
)
|
||||
|
||||
# Run all listener tasks concurrently and wait for them to complete
|
||||
if listener_tasks:
|
||||
await asyncio.gather(*listener_tasks)
|
||||
await asyncio.gather(*listener_tasks)
|
||||
|
||||
async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
|
||||
try:
|
||||
@@ -367,6 +367,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
# If listener does not expect parameters, call without arguments
|
||||
listener_result = await self._execute_method(listener_name, method)
|
||||
|
||||
# Remove from scheduled tasks after execution
|
||||
self._scheduled_tasks.discard(listener_name)
|
||||
|
||||
# Execute listeners of this listener
|
||||
await self._execute_listeners(listener_name, listener_result)
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import io
|
||||
import logging
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
import litellm
|
||||
from litellm import get_supported_openai_params
|
||||
|
||||
@@ -12,6 +9,9 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededException,
|
||||
)
|
||||
|
||||
import sys
|
||||
import io
|
||||
|
||||
|
||||
class FilteredStream(io.StringIO):
|
||||
def write(self, s):
|
||||
|
||||
@@ -34,6 +34,7 @@ class ContextualMemory:
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['context']}" for result in stm_results]
|
||||
)
|
||||
print("formatted_results stm", formatted_results)
|
||||
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||
|
||||
def _fetch_ltm_context(self, task) -> Optional[str]:
|
||||
@@ -53,6 +54,8 @@ class ContextualMemory:
|
||||
formatted_results = list(dict.fromkeys(formatted_results))
|
||||
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
|
||||
|
||||
print("formatted_results ltm", formatted_results)
|
||||
|
||||
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
|
||||
|
||||
def _fetch_entity_context(self, query) -> str:
|
||||
@@ -64,4 +67,5 @@ class ContextualMemory:
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
|
||||
)
|
||||
print("formatted_results em", formatted_results)
|
||||
return f"Entities:\n{formatted_results}" if em_results else ""
|
||||
|
||||
@@ -1,264 +0,0 @@
|
||||
"""Test Flow creation and execution basic functionality."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||
|
||||
|
||||
def test_simple_sequential_flow():
|
||||
"""Test a simple flow with two steps called sequentially."""
|
||||
execution_order = []
|
||||
|
||||
class SimpleFlow(Flow):
|
||||
@start()
|
||||
def step_1(self):
|
||||
execution_order.append("step_1")
|
||||
|
||||
@listen(step_1)
|
||||
def step_2(self):
|
||||
execution_order.append("step_2")
|
||||
|
||||
flow = SimpleFlow()
|
||||
flow.kickoff()
|
||||
|
||||
assert execution_order == ["step_1", "step_2"]
|
||||
|
||||
|
||||
def test_flow_with_multiple_starts():
|
||||
"""Test a flow with multiple start methods."""
|
||||
execution_order = []
|
||||
|
||||
class MultiStartFlow(Flow):
|
||||
@start()
|
||||
def step_a(self):
|
||||
execution_order.append("step_a")
|
||||
|
||||
@start()
|
||||
def step_b(self):
|
||||
execution_order.append("step_b")
|
||||
|
||||
@listen(step_a)
|
||||
def step_c(self):
|
||||
execution_order.append("step_c")
|
||||
|
||||
@listen(step_b)
|
||||
def step_d(self):
|
||||
execution_order.append("step_d")
|
||||
|
||||
flow = MultiStartFlow()
|
||||
flow.kickoff()
|
||||
|
||||
assert "step_a" in execution_order
|
||||
assert "step_b" in execution_order
|
||||
assert "step_c" in execution_order
|
||||
assert "step_d" in execution_order
|
||||
assert execution_order.index("step_c") > execution_order.index("step_a")
|
||||
assert execution_order.index("step_d") > execution_order.index("step_b")
|
||||
|
||||
|
||||
def test_cyclic_flow():
|
||||
"""Test a cyclic flow that runs a finite number of iterations."""
|
||||
execution_order = []
|
||||
|
||||
class CyclicFlow(Flow):
|
||||
iteration = 0
|
||||
max_iterations = 3
|
||||
|
||||
@start("loop")
|
||||
def step_1(self):
|
||||
if self.iteration >= self.max_iterations:
|
||||
return # Do not proceed further
|
||||
execution_order.append(f"step_1_{self.iteration}")
|
||||
|
||||
@listen(step_1)
|
||||
def step_2(self):
|
||||
execution_order.append(f"step_2_{self.iteration}")
|
||||
|
||||
@router(step_2)
|
||||
def step_3(self):
|
||||
execution_order.append(f"step_3_{self.iteration}")
|
||||
self.iteration += 1
|
||||
if self.iteration < self.max_iterations:
|
||||
return "loop"
|
||||
|
||||
return "exit"
|
||||
|
||||
flow = CyclicFlow()
|
||||
flow.kickoff()
|
||||
|
||||
expected_order = []
|
||||
for i in range(flow.max_iterations):
|
||||
expected_order.extend([f"step_1_{i}", f"step_2_{i}", f"step_3_{i}"])
|
||||
|
||||
assert execution_order == expected_order
|
||||
|
||||
|
||||
def test_flow_with_and_condition():
|
||||
"""Test a flow where a step waits for multiple other steps to complete."""
|
||||
execution_order = []
|
||||
|
||||
class AndConditionFlow(Flow):
|
||||
@start()
|
||||
def step_1(self):
|
||||
execution_order.append("step_1")
|
||||
|
||||
@start()
|
||||
def step_2(self):
|
||||
execution_order.append("step_2")
|
||||
|
||||
@listen(and_(step_1, step_2))
|
||||
def step_3(self):
|
||||
execution_order.append("step_3")
|
||||
|
||||
flow = AndConditionFlow()
|
||||
flow.kickoff()
|
||||
|
||||
assert "step_1" in execution_order
|
||||
assert "step_2" in execution_order
|
||||
assert execution_order[-1] == "step_3"
|
||||
assert execution_order.index("step_3") > execution_order.index("step_1")
|
||||
assert execution_order.index("step_3") > execution_order.index("step_2")
|
||||
|
||||
|
||||
def test_flow_with_or_condition():
|
||||
"""Test a flow where a step is triggered when any of multiple steps complete."""
|
||||
execution_order = []
|
||||
|
||||
class OrConditionFlow(Flow):
|
||||
@start()
|
||||
def step_a(self):
|
||||
execution_order.append("step_a")
|
||||
|
||||
@start()
|
||||
def step_b(self):
|
||||
execution_order.append("step_b")
|
||||
|
||||
@listen(or_(step_a, step_b))
|
||||
def step_c(self):
|
||||
execution_order.append("step_c")
|
||||
|
||||
flow = OrConditionFlow()
|
||||
flow.kickoff()
|
||||
|
||||
assert "step_a" in execution_order or "step_b" in execution_order
|
||||
assert "step_c" in execution_order
|
||||
assert execution_order.index("step_c") > min(
|
||||
execution_order.index("step_a"), execution_order.index("step_b")
|
||||
)
|
||||
|
||||
|
||||
def test_flow_with_router():
|
||||
"""Test a flow that uses a router method to determine the next step."""
|
||||
execution_order = []
|
||||
|
||||
class RouterFlow(Flow):
|
||||
@start()
|
||||
def start_method(self):
|
||||
execution_order.append("start_method")
|
||||
|
||||
@router(start_method)
|
||||
def router(self):
|
||||
execution_order.append("router")
|
||||
# Ensure the condition is set to True to follow the "step_if_true" path
|
||||
condition = True
|
||||
return "step_if_true" if condition else "step_if_false"
|
||||
|
||||
@listen("step_if_true")
|
||||
def truthy(self):
|
||||
execution_order.append("step_if_true")
|
||||
|
||||
@listen("step_if_false")
|
||||
def falsy(self):
|
||||
execution_order.append("step_if_false")
|
||||
|
||||
flow = RouterFlow()
|
||||
flow.kickoff()
|
||||
|
||||
assert execution_order == ["start_method", "router", "step_if_true"]
|
||||
|
||||
|
||||
def test_async_flow():
|
||||
"""Test an asynchronous flow."""
|
||||
execution_order = []
|
||||
|
||||
class AsyncFlow(Flow):
|
||||
@start()
|
||||
async def step_1(self):
|
||||
execution_order.append("step_1")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
@listen(step_1)
|
||||
async def step_2(self):
|
||||
execution_order.append("step_2")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
flow = AsyncFlow()
|
||||
asyncio.run(flow.kickoff_async())
|
||||
|
||||
assert execution_order == ["step_1", "step_2"]
|
||||
|
||||
|
||||
def test_flow_with_exceptions():
|
||||
"""Test flow behavior when exceptions occur in steps."""
|
||||
execution_order = []
|
||||
|
||||
class ExceptionFlow(Flow):
|
||||
@start()
|
||||
def step_1(self):
|
||||
execution_order.append("step_1")
|
||||
raise ValueError("An error occurred in step_1")
|
||||
|
||||
@listen(step_1)
|
||||
def step_2(self):
|
||||
execution_order.append("step_2")
|
||||
|
||||
flow = ExceptionFlow()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
flow.kickoff()
|
||||
|
||||
# Ensure step_2 did not execute
|
||||
assert execution_order == ["step_1"]
|
||||
|
||||
|
||||
def test_flow_restart():
|
||||
"""Test restarting a flow after it has completed."""
|
||||
execution_order = []
|
||||
|
||||
class RestartableFlow(Flow):
|
||||
@start()
|
||||
def step_1(self):
|
||||
execution_order.append("step_1")
|
||||
|
||||
@listen(step_1)
|
||||
def step_2(self):
|
||||
execution_order.append("step_2")
|
||||
|
||||
flow = RestartableFlow()
|
||||
flow.kickoff()
|
||||
flow.kickoff() # Restart the flow
|
||||
|
||||
assert execution_order == ["step_1", "step_2", "step_1", "step_2"]
|
||||
|
||||
|
||||
def test_flow_with_custom_state():
|
||||
"""Test a flow that maintains and modifies internal state."""
|
||||
|
||||
class StateFlow(Flow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.counter = 0
|
||||
|
||||
@start()
|
||||
def step_1(self):
|
||||
self.counter += 1
|
||||
|
||||
@listen(step_1)
|
||||
def step_2(self):
|
||||
self.counter *= 2
|
||||
assert self.counter == 2
|
||||
|
||||
flow = StateFlow()
|
||||
flow.kickoff()
|
||||
assert flow.counter == 2
|
||||
Reference in New Issue
Block a user