Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-29 01:58:14 +00:00

Compare commits: feat/docli ... 27472ba69e (5 commits)

- 27472ba69e
- 25aa774d8c
- 6cc2f510bf
- 9a65abf6b8
- b3185ad90c

@@ -79,6 +79,55 @@ crew = Crew(
 result = crew.kickoff(inputs={"question": "What city does John live in and how old is he?"})
 ```
+
+Here's another example with the `CrewDoclingSource`
+
+```python Code
+from crewai import LLM, Agent, Crew, Process, Task
+from crewai.knowledge.source.crew_docling_source import CrewDoclingSource
+
+# Create a knowledge source
+content_source = CrewDoclingSource(
+    file_paths=[
+        "https://lilianweng.github.io/posts/2024-11-28-reward-hacking",
+        "https://lilianweng.github.io/posts/2024-07-07-hallucination",
+    ],
+)
+
+# Create an LLM with a temperature of 0 to ensure deterministic outputs
+llm = LLM(model="gpt-4o-mini", temperature=0)
+
+# Create an agent with the knowledge store
+agent = Agent(
+    role="About papers",
+    goal="You know everything about the papers.",
+    backstory="""You are a master at understanding papers and their content.""",
+    verbose=True,
+    allow_delegation=False,
+    llm=llm,
+)
+task = Task(
+    description="Answer the following questions about the papers: {question}",
+    expected_output="An answer to the question.",
+    agent=agent,
+)
+
+crew = Crew(
+    agents=[agent],
+    tasks=[task],
+    verbose=True,
+    process=Process.sequential,
+    knowledge_sources=[
+        content_source
+    ],  # Enable knowledge by adding the sources here. You can also add more sources to the sources list.
+)
+
+result = crew.kickoff(
+    inputs={
+        "question": "What is the reward hacking paper about? Be sure to provide sources."
+    }
+)
+```
 
 ## Knowledge Configuration
 
 ### Chunking Configuration

@@ -51,6 +51,9 @@ openpyxl = [
     "openpyxl>=3.1.5",
 ]
 mem0 = ["mem0ai>=0.1.29"]
+docling = [
+    "docling>=2.12.0",
+]
 
 [tool.uv]
 dev-dependencies = [

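With this optional dependency group in place, installing the extra (for example `pip install "crewai[docling]"`, assuming the extra is published under that name) would pull in `docling>=2.12.0` for the `CrewDoclingSource` shown above.
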
@@ -80,10 +80,27 @@ def listen(condition):
     return decorator
 
 
-def router(method):
+def router(condition):
     def decorator(func):
         func.__is_router__ = True
-        func.__router_for__ = method.__name__
+        # Handle conditions like listen/start
+        if isinstance(condition, str):
+            func.__trigger_methods__ = [condition]
+            func.__condition_type__ = "OR"
+        elif (
+            isinstance(condition, dict)
+            and "type" in condition
+            and "methods" in condition
+        ):
+            func.__trigger_methods__ = condition["methods"]
+            func.__condition_type__ = condition["type"]
+        elif callable(condition) and hasattr(condition, "__name__"):
+            func.__trigger_methods__ = [condition.__name__]
+            func.__condition_type__ = "OR"
+        else:
+            raise ValueError(
+                "Condition must be a method, string, or a result of or_() or and_()"
+            )
         return func
 
     return decorator

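A minimal sketch of the three condition forms the reworked decorator accepts, based only on the branches above (the flow class and method names are invented for illustration):

```python
from crewai.flow.flow import Flow, or_, router, start


class RouterForms(Flow):
    @start()
    def step_a(self):
        pass

    @router(step_a)  # a method reference: OR on its __name__
    def after_a(self):
        return "go_left"

    @router("go_left")  # a plain string label
    def after_left(self):
        return "left_done"

    @router(or_("step_a", "after_left"))  # or_()/and_() produce the dict form
    def after_any(self):
        return "any_done"
```

Any other argument falls through to the `ValueError` branch, so a bare lambda without a usable `__name__` or a malformed dict fails fast at class-definition time.
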
@@ -123,8 +140,8 @@ class FlowMeta(type):
 
         start_methods = []
         listeners = {}
-        routers = {}
         router_paths = {}
+        routers = set()
 
         for attr_name, attr_value in dct.items():
             if hasattr(attr_value, "__is_start_method__"):

@@ -137,18 +154,11 @@ class FlowMeta(type):
                 methods = attr_value.__trigger_methods__
                 condition_type = getattr(attr_value, "__condition_type__", "OR")
                 listeners[attr_name] = (condition_type, methods)
-
-            elif hasattr(attr_value, "__is_router__"):
-                routers[attr_value.__router_for__] = attr_name
-                possible_returns = get_possible_return_constants(attr_value)
-                if possible_returns:
-                    router_paths[attr_name] = possible_returns
-
-                # Register router as a listener to its triggering method
-                trigger_method_name = attr_value.__router_for__
-                methods = [trigger_method_name]
-                condition_type = "OR"
-                listeners[attr_name] = (condition_type, methods)
+                if hasattr(attr_value, "__is_router__") and attr_value.__is_router__:
+                    routers.add(attr_name)
+                    possible_returns = get_possible_return_constants(attr_value)
+                    if possible_returns:
+                        router_paths[attr_name] = possible_returns
 
         setattr(cls, "_start_methods", start_methods)
         setattr(cls, "_listeners", listeners)

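To make the new bookkeeping concrete, here is a sketch (hypothetical flow; expected values inferred from the hunks above, not output copied from the PR) of what the metaclass now records for a single router:

```python
from crewai.flow.flow import Flow, router, start


class Tiny(Flow):
    @start()
    def step_a(self):
        pass

    @router("step_a")
    def choose(self):
        return "left"


assert "choose" in Tiny._routers                        # routers is now a set of names
assert Tiny._listeners["choose"] == ("OR", ["step_a"])  # a router doubles as a listener
assert Tiny._router_paths["choose"] == ["left"]         # via get_possible_return_constants
```
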
@@ -163,7 +173,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
 
     _start_methods: List[str] = []
     _listeners: Dict[str, tuple[str, List[str]]] = {}
-    _routers: Dict[str, str] = {}
+    _routers: Set[str] = set()
     _router_paths: Dict[str, List[str]] = {}
     initial_state: Union[Type[T], T, None] = None
     event_emitter = Signal("event_emitter")

@@ -210,20 +220,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
         return self._method_outputs
 
     def _initialize_state(self, inputs: Dict[str, Any]) -> None:
-        """
-        Initializes or updates the state with the provided inputs.
-
-        Args:
-            inputs: Dictionary of inputs to initialize or update the state.
-
-        Raises:
-            ValueError: If inputs do not match the structured state model.
-            TypeError: If state is neither a BaseModel instance nor a dictionary.
-        """
         if isinstance(self._state, BaseModel):
-            # Structured state management
+            # Structured state
             try:
-                # Define a function to create the dynamic class
                 def create_model_with_extra_forbid(
                     base_model: Type[BaseModel],
                 ) -> Type[BaseModel]:

@@ -233,34 +233,20 @@ class Flow(Generic[T], metaclass=FlowMeta):
 
                     return ModelWithExtraForbid
 
-                # Create the dynamic class
                 ModelWithExtraForbid = create_model_with_extra_forbid(
                     self._state.__class__
                 )
 
-                # Create a new instance using the combined state and inputs
                 self._state = cast(
                     T, ModelWithExtraForbid(**{**self._state.model_dump(), **inputs})
                 )
 
             except ValidationError as e:
                 raise ValueError(f"Invalid inputs for structured state: {e}") from e
         elif isinstance(self._state, dict):
-            # Unstructured state management
             self._state.update(inputs)
         else:
             raise TypeError("State must be a BaseModel instance or a dictionary.")
 
     def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
-        """
-        Starts the execution of the flow synchronously.
-
-        Args:
-            inputs: Optional dictionary of inputs to initialize or update the state.
-
-        Returns:
-            The final output from the flow execution.
-        """
         self.event_emitter.send(
             self,
             event=FlowStartedEvent(

@@ -274,15 +260,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
         return asyncio.run(self.kickoff_async())
 
     async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
-        """
-        Starts the execution of the flow asynchronously.
-
-        Args:
-            inputs: Optional dictionary of inputs to initialize or update the state.
-
-        Returns:
-            The final output from the flow execution.
-        """
         if not self._start_methods:
             raise ValueError("No start method defined")
 

@@ -290,16 +267,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
             self.__class__.__name__, list(self._methods.keys())
         )
 
-        # Create tasks for all start methods
         tasks = [
             self._execute_start_method(start_method)
             for start_method in self._start_methods
         ]
 
-        # Run all start methods concurrently
         await asyncio.gather(*tasks)
 
-        # Determine the final output (from the last executed method)
         final_output = self._method_outputs[-1] if self._method_outputs else None
 
         self.event_emitter.send(

@@ -310,7 +283,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
                 result=final_output,
             ),
         )
-
         return final_output
 
     async def _execute_start_method(self, start_method_name: str) -> None:

@@ -327,49 +299,68 @@ class Flow(Generic[T], metaclass=FlowMeta):
             if asyncio.iscoroutinefunction(method)
             else method(*args, **kwargs)
         )
-        self._method_outputs.append(result)  # Store the output
+        self._method_outputs.append(result)
 
-        # Track method execution counts
         self._method_execution_counts[method_name] = (
             self._method_execution_counts.get(method_name, 0) + 1
         )
 
        return result
 
     async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
-        listener_tasks = []
-
-        if trigger_method in self._routers:
-            router_method = self._methods[self._routers[trigger_method]]
-            path = await self._execute_method(
-                self._routers[trigger_method], router_method
-            )
-            trigger_method = path
-
+        # First, handle routers repeatedly until no router triggers anymore
+        while True:
+            routers_triggered = self._find_triggered_methods(
+                trigger_method, router_only=True
+            )
+            if not routers_triggered:
+                break
+            for router_name in routers_triggered:
+                await self._execute_single_listener(router_name, result)
+                # After executing router, the router's result is the path
+                # The last router executed sets the trigger_method
+                # The router result is the last element in self._method_outputs
+                trigger_method = self._method_outputs[-1]
+
+        # Now that no more routers are triggered by current trigger_method,
+        # execute normal listeners
+        listeners_triggered = self._find_triggered_methods(
+            trigger_method, router_only=False
+        )
+        if listeners_triggered:
+            tasks = [
+                self._execute_single_listener(listener_name, result)
+                for listener_name in listeners_triggered
+            ]
+            await asyncio.gather(*tasks)
+
+    def _find_triggered_methods(
+        self, trigger_method: str, router_only: bool
+    ) -> List[str]:
+        triggered = []
         for listener_name, (condition_type, methods) in self._listeners.items():
+            is_router = listener_name in self._routers
+
+            if router_only != is_router:
+                continue
+
             if condition_type == "OR":
+                # If the trigger_method matches any in methods, run this
                 if trigger_method in methods:
-                    # Schedule the listener without preventing re-execution
-                    listener_tasks.append(
-                        self._execute_single_listener(listener_name, result)
-                    )
+                    triggered.append(listener_name)
             elif condition_type == "AND":
                 # Initialize pending methods for this listener if not already done
                 if listener_name not in self._pending_and_listeners:
                     self._pending_and_listeners[listener_name] = set(methods)
                 # Remove the trigger method from pending methods
-                self._pending_and_listeners[listener_name].discard(trigger_method)
+                if trigger_method in self._pending_and_listeners[listener_name]:
+                    self._pending_and_listeners[listener_name].discard(trigger_method)
+
                 if not self._pending_and_listeners[listener_name]:
                     # All required methods have been executed
-                    listener_tasks.append(
-                        self._execute_single_listener(listener_name, result)
-                    )
+                    triggered.append(listener_name)
                     # Reset pending methods for this listener
                     self._pending_and_listeners.pop(listener_name, None)
 
-        # Run all listener tasks concurrently and wait for them to complete
-        if listener_tasks:
-            await asyncio.gather(*listener_tasks)
+        return triggered
 
     async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
         try:

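A small sketch (class and method names invented here) of why the router loop matters: one router's returned label can immediately trigger another router, and only when no router fires does the flow fall through to plain listeners:

```python
from crewai.flow.flow import Flow, listen, router, start


class ChainedRouters(Flow):
    @start()
    def begin(self):
        pass

    @router("begin")
    def router_a(self):
        return "go_b"  # this string becomes the next trigger_method

    @router("go_b")
    def router_b(self):
        return "done"

    @listen("done")
    def finish(self):
        print("reached finish")


ChainedRouters().kickoff()  # begin -> router_a -> router_b -> finish
```

Under the old implementation, routers were keyed by the single method they followed, so a router could not itself be triggered by another router's return path; the `while True` loop above removes that restriction.
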
@@ -386,17 +377,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
 
             sig = inspect.signature(method)
             params = list(sig.parameters.values())
-
-            # Exclude 'self' parameter
             method_params = [p for p in params if p.name != "self"]
 
             if method_params:
-                # If listener expects parameters, pass the result
                 listener_result = await self._execute_method(
                     listener_name, method, result
                 )
             else:
-                # If listener does not expect parameters, call without arguments
                 listener_result = await self._execute_method(listener_name, method)
 
             self.event_emitter.send(

@@ -408,8 +395,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
                 ),
             )
 
-            # Execute listeners of this listener
+            # Execute listeners (and possibly routers) of this listener
             await self._execute_listeners(listener_name, listener_result)
+
         except Exception as e:
             print(
                 f"[Flow._execute_single_listener] Error in method {listener_name}: {e}"

@@ -422,5 +410,4 @@ class Flow(Generic[T], metaclass=FlowMeta):
         self._telemetry.flow_plotting_span(
             self.__class__.__name__, list(self._methods.keys())
         )
-
         plot_flow(self, filename)

@@ -31,16 +31,50 @@ def get_possible_return_constants(function):
         print(f"Source code:\n{source}")
         return None
 
-    return_values = []
+    return_values = set()
+    dict_definitions = {}
+
+    class DictionaryAssignmentVisitor(ast.NodeVisitor):
+        def visit_Assign(self, node):
+            # Check if this assignment is assigning a dictionary literal to a variable
+            if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
+                target = node.targets[0]
+                if isinstance(target, ast.Name):
+                    var_name = target.id
+                    dict_values = []
+                    # Extract string values from the dictionary
+                    for val in node.value.values:
+                        if isinstance(val, ast.Constant) and isinstance(val.value, str):
+                            dict_values.append(val.value)
+                        # If non-string, skip or just ignore
+                    if dict_values:
+                        dict_definitions[var_name] = dict_values
+            self.generic_visit(node)
 
     class ReturnVisitor(ast.NodeVisitor):
         def visit_Return(self, node):
-            # Check if the return value is a constant (Python 3.8+)
-            if isinstance(node.value, ast.Constant):
-                return_values.append(node.value.value)
+            # Direct string return
+            if isinstance(node.value, ast.Constant) and isinstance(
+                node.value.value, str
+            ):
+                return_values.add(node.value.value)
+            # Dictionary-based return, like return paths[result]
+            elif isinstance(node.value, ast.Subscript):
+                # Check if we're subscripting a known dictionary variable
+                if isinstance(node.value.value, ast.Name):
+                    var_name = node.value.value.id
+                    if var_name in dict_definitions:
+                        # Add all possible dictionary values
+                        for v in dict_definitions[var_name]:
+                            return_values.add(v)
+            self.generic_visit(node)
 
+    # First pass: identify dictionary assignments
+    DictionaryAssignmentVisitor().visit(code_ast)
+    # Second pass: identify returns
     ReturnVisitor().visit(code_ast)
-    return return_values
+
+    return list(return_values) if return_values else None
 
 
 def calculate_node_levels(flow):

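As an illustration (hypothetical router body, not taken from the PR), the new two-pass analysis lets a router return through a dictionary lookup and still have its paths discovered: the dict literal is recorded in the first pass, and the subscripted return in the second pass expands to every string value of that dict.

```python
def route(self):
    # First pass records: paths -> ["success_path", "failure_path"]
    paths = {"success": "success_path", "failure": "failure_path"}
    result = "success" if self.state.get("ok") else "failure"
    # Second pass sees `return paths[result]` and adds both values
    return paths[result]

# get_possible_return_constants(route) would yield
# ["success_path", "failure_path"] (a list built from a set, so order may vary).
```
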
@@ -61,10 +95,7 @@ def calculate_node_levels(flow):
         current_level = levels[current]
         visited.add(current)
 
-        for listener_name, (
-            condition_type,
-            trigger_methods,
-        ) in flow._listeners.items():
+        for listener_name, (condition_type, trigger_methods) in flow._listeners.items():
             if condition_type == "OR":
                 if current in trigger_methods:
                     if (

@@ -89,7 +120,7 @@ def calculate_node_levels(flow):
                         queue.append(listener_name)
 
         # Handle router connections
-        if current in flow._routers.values():
+        if current in flow._routers:
             router_method_name = current
             paths = flow._router_paths.get(router_method_name, [])
             for path in paths:

@@ -105,6 +136,7 @@ def calculate_node_levels(flow):
                     levels[listener_name] = current_level + 1
                     if listener_name not in visited:
                         queue.append(listener_name)
+
     return levels
 
 

@@ -142,7 +174,7 @@ def dfs_ancestors(node, ancestors, visited, flow):
             dfs_ancestors(listener_name, ancestors, visited, flow)
 
     # Handle router methods separately
-    if node in flow._routers.values():
+    if node in flow._routers:
         router_method_name = node
         paths = flow._router_paths.get(router_method_name, [])
         for path in paths:

|||||||
@@ -94,12 +94,14 @@ def add_edges(net, flow, node_positions, colors):
|
|||||||
ancestors = build_ancestor_dict(flow)
|
ancestors = build_ancestor_dict(flow)
|
||||||
parent_children = build_parent_children_dict(flow)
|
parent_children = build_parent_children_dict(flow)
|
||||||
|
|
||||||
|
# Edges for normal listeners
|
||||||
for method_name in flow._listeners:
|
for method_name in flow._listeners:
|
||||||
condition_type, trigger_methods = flow._listeners[method_name]
|
condition_type, trigger_methods = flow._listeners[method_name]
|
||||||
is_and_condition = condition_type == "AND"
|
is_and_condition = condition_type == "AND"
|
||||||
|
|
||||||
for trigger in trigger_methods:
|
for trigger in trigger_methods:
|
||||||
if trigger in flow._methods or trigger in flow._routers.values():
|
# Check if nodes exist before adding edges
|
||||||
|
if trigger in node_positions and method_name in node_positions:
|
||||||
is_router_edge = any(
|
is_router_edge = any(
|
||||||
trigger in paths for paths in flow._router_paths.values()
|
trigger in paths for paths in flow._router_paths.values()
|
||||||
)
|
)
|
||||||
@@ -135,7 +137,22 @@ def add_edges(net, flow, node_positions, colors):
                 }
 
                 net.add_edge(trigger, method_name, **edge_style)
+            else:
+                # Nodes not found in node_positions. Check if it's a known router outcome and a known method.
+                is_router_edge = any(
+                    trigger in paths for paths in flow._router_paths.values()
+                )
+                # Check if method_name is a known method
+                method_known = method_name in flow._methods
+
+                # If it's a known router edge and the method is known, don't warn.
+                # This means the path is legitimate, just not reflected as nodes here.
+                if not (is_router_edge and method_known):
+                    print(
+                        f"Warning: No node found for '{trigger}' or '{method_name}'. Skipping edge."
+                    )
 
+    # Edges for router return paths
     for router_method_name, paths in flow._router_paths.items():
         for path in paths:
             for listener_name, (

@@ -143,36 +160,49 @@ def add_edges(net, flow, node_positions, colors):
                 trigger_methods,
             ) in flow._listeners.items():
                 if path in trigger_methods:
-                    is_cycle_edge = is_ancestor(trigger, method_name, ancestors)
-                    parent_has_multiple_children = (
-                        len(parent_children.get(router_method_name, [])) > 1
-                    )
-                    needs_curvature = is_cycle_edge or parent_has_multiple_children
-
-                    if needs_curvature:
-                        source_pos = node_positions.get(router_method_name)
-                        target_pos = node_positions.get(listener_name)
-
-                        if source_pos and target_pos:
-                            dx = target_pos[0] - source_pos[0]
-                            smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
-                            index = get_child_index(
-                                router_method_name, listener_name, parent_children
-                            )
-                            edge_smooth = {
-                                "type": smooth_type,
-                                "roundness": 0.2 + (0.1 * index),
-                            }
-                        else:
-                            edge_smooth = {"type": "cubicBezier"}
-                    else:
-                        edge_smooth = False
-
-                    edge_style = {
-                        "color": colors["router_edge"],
-                        "width": 2,
-                        "arrows": "to",
-                        "dashes": True,
-                        "smooth": edge_smooth,
-                    }
-                    net.add_edge(router_method_name, listener_name, **edge_style)
+                    if (
+                        router_method_name in node_positions
+                        and listener_name in node_positions
+                    ):
+                        is_cycle_edge = is_ancestor(
+                            router_method_name, listener_name, ancestors
+                        )
+                        parent_has_multiple_children = (
+                            len(parent_children.get(router_method_name, [])) > 1
+                        )
+                        needs_curvature = is_cycle_edge or parent_has_multiple_children
+
+                        if needs_curvature:
+                            source_pos = node_positions.get(router_method_name)
+                            target_pos = node_positions.get(listener_name)
+
+                            if source_pos and target_pos:
+                                dx = target_pos[0] - source_pos[0]
+                                smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
+                                index = get_child_index(
+                                    router_method_name, listener_name, parent_children
+                                )
+                                edge_smooth = {
+                                    "type": smooth_type,
+                                    "roundness": 0.2 + (0.1 * index),
+                                }
+                            else:
+                                edge_smooth = {"type": "cubicBezier"}
+                        else:
+                            edge_smooth = False
+
+                        edge_style = {
+                            "color": colors["router_edge"],
+                            "width": 2,
+                            "arrows": "to",
+                            "dashes": True,
+                            "smooth": edge_smooth,
+                        }
+                        net.add_edge(router_method_name, listener_name, **edge_style)
+                    else:
+                        # Same check here: known router edge and known method?
+                        method_known = listener_name in flow._methods
+                        if not method_known:
+                            print(
+                                f"Warning: No node found for '{router_method_name}' or '{listener_name}'. Skipping edge."
+                            )

@@ -14,13 +14,13 @@ class Knowledge(BaseModel):
     Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
     Args:
         sources: List[BaseKnowledgeSource] = Field(default_factory=list)
-        storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
+        storage: Optional[KnowledgeStorage] = Field(default=None)
         embedder_config: Optional[Dict[str, Any]] = None
     """
 
     sources: List[BaseKnowledgeSource] = Field(default_factory=list)
     model_config = ConfigDict(arbitrary_types_allowed=True)
-    storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
+    storage: Optional[KnowledgeStorage] = Field(default=None)
     embedder_config: Optional[Dict[str, Any]] = None
     collection_name: Optional[str] = None
 

@@ -1,8 +1,8 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 
-from pydantic import Field
+from pydantic import Field, field_validator
 
 from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
 from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage

@@ -14,17 +14,28 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
     """Base class for knowledge sources that load content from files."""
 
     _logger: Logger = Logger(verbose=True)
-    file_path: Union[Path, List[Path], str, List[str]] = Field(
-        ..., description="The path to the file"
+    file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
+        default=None,
+        description="[Deprecated] The path to the file. Use file_paths instead.",
+    )
+    file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
+        default_factory=list, description="The path to the file"
     )
     content: Dict[Path, str] = Field(init=False, default_factory=dict)
-    storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
+    storage: Optional[KnowledgeStorage] = Field(default=None)
     safe_file_paths: List[Path] = Field(default_factory=list)
 
+    @field_validator("file_path", "file_paths", mode="before")
+    def validate_file_path(cls, v, values):
+        """Validate that at least one of file_path or file_paths is provided."""
+        if v is None and ("file_path" not in values or values.get("file_path") is None):
+            raise ValueError("Either file_path or file_paths must be provided")
+        return v
+
     def model_post_init(self, _):
         """Post-initialization method to load content."""
         self.safe_file_paths = self._process_file_paths()
-        self.validate_paths()
+        self.validate_content()
         self.content = self.load_content()
 
     @abstractmethod

@@ -32,7 +43,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
         """Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
         pass
 
-    def validate_paths(self):
+    def validate_content(self):
         """Validate the paths."""
         for path in self.safe_file_paths:
             if not path.exists():

@@ -51,7 +62,10 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
 
     def _save_documents(self):
         """Save the documents to the storage."""
-        self.storage.save(self.chunks)
+        if self.storage:
+            self.storage.save(self.chunks)
+        else:
+            raise ValueError("No storage found to save documents.")
 
     def convert_to_path(self, path: Union[Path, str]) -> Path:
         """Convert a path to a Path object."""

@@ -59,13 +73,30 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
 
     def _process_file_paths(self) -> List[Path]:
         """Convert file_path to a list of Path objects."""
-        paths = (
-            [self.file_path]
-            if isinstance(self.file_path, (str, Path))
-            else self.file_path
+        if hasattr(self, "file_path") and self.file_path is not None:
+            self._logger.log(
+                "warning",
+                "The 'file_path' attribute is deprecated and will be removed in a future version. Please use 'file_paths' instead.",
+                color="yellow",
+            )
+            self.file_paths = self.file_path
+
+        if self.file_paths is None:
+            raise ValueError("Your source must be provided with a file_paths: []")
+
+        # Convert single path to list
+        path_list: List[Union[Path, str]] = (
+            [self.file_paths]
+            if isinstance(self.file_paths, (str, Path))
+            else list(self.file_paths)
+            if isinstance(self.file_paths, list)
+            else []
         )
 
-        if not isinstance(paths, list):
-            raise ValueError("file_path must be a Path, str, or a list of these types")
+        if not path_list:
+            raise ValueError(
+                "file_path/file_paths must be a Path, str, or a list of these types"
+            )
 
-        return [self.convert_to_path(path) for path in paths]
+        return [self.convert_to_path(path) for path in path_list]

|||||||
@@ -16,12 +16,12 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
|||||||
chunk_embeddings: List[np.ndarray] = Field(default_factory=list)
|
chunk_embeddings: List[np.ndarray] = Field(default_factory=list)
|
||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
|
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict) # Currently unused
|
metadata: Dict[str, Any] = Field(default_factory=dict) # Currently unused
|
||||||
collection_name: Optional[str] = Field(default=None)
|
collection_name: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_content(self) -> Dict[Any, str]:
|
def validate_content(self) -> Any:
|
||||||
"""Load and preprocess content from the source."""
|
"""Load and preprocess content from the source."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -46,4 +46,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
         Save the documents to the storage.
         This method should be called after the chunks and embeddings are generated.
         """
-        self.storage.save(self.chunks)
+        if self.storage:
+            self.storage.save(self.chunks)
+        else:
+            raise ValueError("No storage found to save documents.")

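The practical consequence (an interpretation of this change, sketched with invented content): sources no longer conjure a default store, so a source used standalone, without a crew attaching storage, now fails loudly instead of writing to an implicitly created one.

```python
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

source = StringKnowledgeSource(content="CrewAI ships crews and flows.")
try:
    source.add()  # chunks the string, then tries to persist the chunks
except ValueError as err:
    print(err)    # -> "No storage found to save documents."
```
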
New file (120 lines): src/crewai/knowledge/source/crew_docling_source.py

@@ -0,0 +1,120 @@
+from pathlib import Path
+from typing import Iterator, List, Optional, Union
+from urllib.parse import urlparse
+
+from docling.datamodel.base_models import InputFormat
+from docling.document_converter import DocumentConverter
+from docling.exceptions import ConversionError
+from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
+from docling_core.types.doc.document import DoclingDocument
+from pydantic import Field
+
+from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
+from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
+from crewai.utilities.logger import Logger
+
+
+class CrewDoclingSource(BaseKnowledgeSource):
+    """Default Source class for converting documents to markdown or json
+
+    This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
+    """
+
+    _logger: Logger = Logger(verbose=True)
+
+    file_path: Optional[List[Union[Path, str]]] = Field(default=None)
+    file_paths: List[Union[Path, str]] = Field(default_factory=list)
+    chunks: List[str] = Field(default_factory=list)
+    safe_file_paths: List[Union[Path, str]] = Field(default_factory=list)
+    content: List[DoclingDocument] = Field(default_factory=list)
+    document_converter: DocumentConverter = Field(
+        default_factory=lambda: DocumentConverter(
+            allowed_formats=[
+                InputFormat.MD,
+                InputFormat.ASCIIDOC,
+                InputFormat.PDF,
+                InputFormat.DOCX,
+                InputFormat.HTML,
+                InputFormat.IMAGE,
+                InputFormat.XLSX,
+                InputFormat.PPTX,
+            ]
+        )
+    )
+
+    def model_post_init(self, _) -> None:
+        if self.file_path:
+            self._logger.log(
+                "warning",
+                "The 'file_path' attribute is deprecated and will be removed in a future version. Please use 'file_paths' instead.",
+                color="yellow",
+            )
+            self.file_paths = self.file_path
+        self.safe_file_paths = self.validate_content()
+        self.content = self._load_content()
+
+    def _load_content(self) -> List[DoclingDocument]:
+        try:
+            return self._convert_source_to_docling_documents()
+        except ConversionError as e:
+            self._logger.log(
+                "error",
+                f"Error loading content: {e}. Supported formats: {self.document_converter.allowed_formats}",
+                "red",
+            )
+            raise e
+        except Exception as e:
+            self._logger.log("error", f"Error loading content: {e}")
+            raise e
+
+    def add(self) -> None:
+        if self.content is None:
+            return
+        for doc in self.content:
+            new_chunks_iterable = self._chunk_doc(doc)
+            self.chunks.extend(list(new_chunks_iterable))
+        self._save_documents()
+
+    def _convert_source_to_docling_documents(self) -> List[DoclingDocument]:
+        conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
+        return [result.document for result in conv_results_iter]
+
+    def _chunk_doc(self, doc: DoclingDocument) -> Iterator[str]:
+        chunker = HierarchicalChunker()
+        for chunk in chunker.chunk(doc):
+            yield chunk.text
+
+    def validate_content(self) -> List[Union[Path, str]]:
+        processed_paths: List[Union[Path, str]] = []
+        for path in self.file_paths:
+            if isinstance(path, str):
+                if path.startswith(("http://", "https://")):
+                    try:
+                        if self._validate_url(path):
+                            processed_paths.append(path)
+                        else:
+                            raise ValueError(f"Invalid URL format: {path}")
+                    except Exception as e:
+                        raise ValueError(f"Invalid URL: {path}. Error: {str(e)}")
+                else:
+                    local_path = Path(KNOWLEDGE_DIRECTORY + "/" + path)
+                    if local_path.exists():
+                        processed_paths.append(local_path)
+                    else:
+                        raise FileNotFoundError(f"File not found: {local_path}")
+            else:
+                # this is an instance of Path
+                processed_paths.append(path)
+        return processed_paths
+
+    def _validate_url(self, url: str) -> bool:
+        try:
+            result = urlparse(url)
+            return all(
+                [
+                    result.scheme in ("http", "https"),
+                    result.netloc,
+                    len(result.netloc.split(".")) >= 2,  # Ensure domain has TLD
+                ]
+            )
+        except Exception:
+            return False

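A quick, self-contained sanity check of the URL rules above (this standalone function simply mirrors `_validate_url` so it can be run without installing docling):

```python
from urllib.parse import urlparse


def validate_url(url: str) -> bool:
    # Mirrors CrewDoclingSource._validate_url above.
    result = urlparse(url)
    return all(
        [
            result.scheme in ("http", "https"),
            bool(result.netloc),
            len(result.netloc.split(".")) >= 2,  # host must contain a dot (a TLD)
        ]
    )


assert validate_url("https://lilianweng.github.io/posts/2024-07-07-hallucination/")
assert not validate_url("https://localhost/docs")  # no dot in the host
assert not validate_url("ftp://example.com/file")  # scheme is not http(s)
```
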
@@ -13,9 +13,9 @@ class StringKnowledgeSource(BaseKnowledgeSource):
 
     def model_post_init(self, _):
         """Post-initialization method to validate content."""
-        self.load_content()
+        self.validate_content()
 
-    def load_content(self):
+    def validate_content(self):
         """Validate string content."""
         if not isinstance(self.content, str):
             raise ValueError("StringKnowledgeSource only accepts string content")

@@ -263,3 +263,62 @@ def test_flow_with_custom_state():
     flow = StateFlow()
     flow.kickoff()
     assert flow.counter == 2
+
+
+def test_router_with_multiple_conditions():
+    """Test a router that triggers when any of multiple steps complete (OR condition),
+    and another router that triggers only after all specified steps complete (AND condition).
+    """
+
+    execution_order = []
+
+    class ComplexRouterFlow(Flow):
+        @start()
+        def step_a(self):
+            execution_order.append("step_a")
+
+        @start()
+        def step_b(self):
+            execution_order.append("step_b")
+
+        @router(or_("step_a", "step_b"))
+        def router_or(self):
+            execution_order.append("router_or")
+            return "next_step_or"
+
+        @listen("next_step_or")
+        def handle_next_step_or_event(self):
+            execution_order.append("handle_next_step_or_event")
+
+        @listen(handle_next_step_or_event)
+        def branch_2_step(self):
+            execution_order.append("branch_2_step")
+
+        @router(and_(handle_next_step_or_event, branch_2_step))
+        def router_and(self):
+            execution_order.append("router_and")
+            return "final_step"
+
+        @listen("final_step")
+        def log_final_step(self):
+            execution_order.append("log_final_step")
+
+    flow = ComplexRouterFlow()
+    flow.kickoff()
+
+    assert "step_a" in execution_order
+    assert "step_b" in execution_order
+    assert "router_or" in execution_order
+    assert "handle_next_step_or_event" in execution_order
+    assert "branch_2_step" in execution_order
+    assert "router_and" in execution_order
+    assert "log_final_step" in execution_order
+
+    # Check that the AND router triggered after both relevant steps:
+    assert execution_order.index("router_and") > execution_order.index(
+        "handle_next_step_or_event"
+    )
+    assert execution_order.index("router_and") > execution_order.index("branch_2_step")
+
+    # final_step should run after router_and
+    assert execution_order.index("log_final_step") > execution_order.index("router_and")

@@ -1,10 +1,12 @@
 """Test Knowledge creation and querying functionality."""
 
 from pathlib import Path
+from typing import List, Union
 from unittest.mock import patch
 
 import pytest
 
+from crewai.knowledge.source.crew_docling_source import CrewDoclingSource
 from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource
 from crewai.knowledge.source.excel_knowledge_source import ExcelKnowledgeSource
 from crewai.knowledge.source.json_knowledge_source import JSONKnowledgeSource

@@ -200,7 +202,7 @@ def test_single_short_file(mock_vector_db, tmpdir):
         f.write(content)
 
     file_source = TextFileKnowledgeSource(
-        file_path=file_path, metadata={"preference": "personal"}
+        file_paths=[file_path], metadata={"preference": "personal"}
     )
     mock_vector_db.sources = [file_source]
     mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]

@@ -242,7 +244,7 @@ def test_single_2k_character_file(mock_vector_db, tmpdir):
         f.write(content)
 
     file_source = TextFileKnowledgeSource(
-        file_path=file_path, metadata={"preference": "personal"}
+        file_paths=[file_path], metadata={"preference": "personal"}
    )
     mock_vector_db.sources = [file_source]
     mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]

@@ -279,7 +281,7 @@ def test_multiple_short_files(mock_vector_db, tmpdir):
         file_paths.append((file_path, item["metadata"]))
 
     file_sources = [
-        TextFileKnowledgeSource(file_path=path, metadata=metadata)
+        TextFileKnowledgeSource(file_paths=[path], metadata=metadata)
         for path, metadata in file_paths
     ]
     mock_vector_db.sources = file_sources

@@ -352,7 +354,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
         file_paths.append(file_path)
 
     file_sources = [
-        TextFileKnowledgeSource(file_path=path, metadata={"preference": "personal"})
+        TextFileKnowledgeSource(file_paths=[path], metadata={"preference": "personal"})
         for path in file_paths
     ]
     mock_vector_db.sources = file_sources

@@ -399,7 +401,7 @@ def test_hybrid_string_and_files(mock_vector_db, tmpdir):
         file_paths.append(file_path)
 
     file_sources = [
-        TextFileKnowledgeSource(file_path=path, metadata={"preference": "personal"})
+        TextFileKnowledgeSource(file_paths=[path], metadata={"preference": "personal"})
         for path in file_paths
     ]
 

@@ -424,7 +426,7 @@ def test_pdf_knowledge_source(mock_vector_db):
 
     # Create a PDFKnowledgeSource
     pdf_source = PDFKnowledgeSource(
-        file_path=pdf_path, metadata={"preference": "personal"}
+        file_paths=[pdf_path], metadata={"preference": "personal"}
     )
     mock_vector_db.sources = [pdf_source]
     mock_vector_db.query.return_value = [

@@ -461,7 +463,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
 
     # Create a CSVKnowledgeSource
     csv_source = CSVKnowledgeSource(
-        file_path=csv_path, metadata={"preference": "personal"}
+        file_paths=[csv_path], metadata={"preference": "personal"}
     )
     mock_vector_db.sources = [csv_source]
     mock_vector_db.query.return_value = [

@@ -496,7 +498,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
 
     # Create a JSONKnowledgeSource
     json_source = JSONKnowledgeSource(
-        file_path=json_path, metadata={"preference": "personal"}
+        file_paths=[json_path], metadata={"preference": "personal"}
     )
     mock_vector_db.sources = [json_source]
     mock_vector_db.query.return_value = [

@@ -529,7 +531,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
 
     # Create an ExcelKnowledgeSource
     excel_source = ExcelKnowledgeSource(
-        file_path=excel_path, metadata={"preference": "personal"}
+        file_paths=[excel_path], metadata={"preference": "personal"}
     )
     mock_vector_db.sources = [excel_source]
     mock_vector_db.query.return_value = [

@@ -543,3 +545,42 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
     # Assert that the correct information is retrieved
     assert any("30" in result["context"] for result in results)
     mock_vector_db.query.assert_called_once()
+
+
+def test_docling_source(mock_vector_db):
+    docling_source = CrewDoclingSource(
+        file_paths=[
+            "https://lilianweng.github.io/posts/2024-11-28-reward-hacking/",
+        ],
+    )
+    mock_vector_db.sources = [docling_source]
+    mock_vector_db.query.return_value = [
+        {
+            "context": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.",
+            "score": 0.9,
+        }
+    ]
+    # Perform a query
+    query = "What is reward hacking?"
+    results = mock_vector_db.query(query)
+    assert any("reward hacking" in result["context"].lower() for result in results)
+    mock_vector_db.query.assert_called_once()
+
+
+def test_multiple_docling_sources():
+    urls: List[Union[Path, str]] = [
+        "https://lilianweng.github.io/posts/2024-11-28-reward-hacking/",
+        "https://lilianweng.github.io/posts/2024-07-07-hallucination/",
+    ]
+    docling_source = CrewDoclingSource(file_paths=urls)
+
+    assert docling_source.file_paths == urls
+    assert docling_source.content is not None
+
+
+def test_docling_source_with_local_file():
+    current_dir = Path(__file__).parent
+    pdf_path = current_dir / "crewai_quickstart.pdf"
+    docling_source = CrewDoclingSource(file_paths=[pdf_path])
+    assert docling_source.file_paths == [pdf_path]
+    assert docling_source.content is not None