Compare commits

...

5 Commits

Author SHA1 Message Date
ericklima-ca
27472ba69e refactor: Change storage field to optional and improve error handling when saving documents 2024-12-26 22:27:19 -04:00
ericklima-ca
25aa774d8c fix: Change storage initialization to None for KnowledgeStorage 2024-12-26 21:30:06 -04:00
Brandon Hancock (bhancock_ai)
6cc2f510bf Feat/joao flow improvement requests (#1795)
* Add in or and and in router

* In the middle of improving plotting

* final plot changes

---------

Co-authored-by: João Moura <joaomdmoura@gmail.com>
2024-12-24 18:55:44 -03:00
Lorenze Jay
9a65abf6b8 removed some redundancies (#1796)
* removed some redundancies

* cleanup
2024-12-23 13:54:16 -05:00
Lorenze Jay
b3185ad90c Feat/docling-support (#1763)
* added tool for docling support

* docling support installation

* use file_paths instead of file_path

* fix import

* organized imports

* run_type docs

* needs to be list

* fixed logic

* logged but file_path is backwards compatible

* use file_paths instead of file_path 2

* added test for multiple sources for file_paths

* fix run-types

* enabling local files to work and type cleanup

* linted

* fix test and types

* fixed run types

* fix types

* renamed to CrewDoclingSource

* linted

* added docs

* resolve conflicts

---------

Co-authored-by: Brandon Hancock (bhancock_ai) <109994880+bhancockio@users.noreply.github.com>
Co-authored-by: Brandon Hancock <brandon@brandonhancock.io>
2024-12-23 13:19:58 -05:00
13 changed files with 1412 additions and 166 deletions

View File

@@ -79,6 +79,55 @@ crew = Crew(
result = crew.kickoff(inputs={"question": "What city does John live in and how old is he?"}) result = crew.kickoff(inputs={"question": "What city does John live in and how old is he?"})
``` ```
Here's another example with the `CrewDoclingSource`
```python Code
from crewai import LLM, Agent, Crew, Process, Task
from crewai.knowledge.source.crew_docling_source import CrewDoclingSource
# Create a knowledge source
content_source = CrewDoclingSource(
file_paths=[
"https://lilianweng.github.io/posts/2024-11-28-reward-hacking",
"https://lilianweng.github.io/posts/2024-07-07-hallucination",
],
)
# Create an LLM with a temperature of 0 to ensure deterministic outputs
llm = LLM(model="gpt-4o-mini", temperature=0)
# Create an agent with the knowledge store
agent = Agent(
role="About papers",
goal="You know everything about the papers.",
backstory="""You are a master at understanding papers and their content.""",
verbose=True,
allow_delegation=False,
llm=llm,
)
task = Task(
description="Answer the following questions about the papers: {question}",
expected_output="An answer to the question.",
agent=agent,
)
crew = Crew(
agents=[agent],
tasks=[task],
verbose=True,
process=Process.sequential,
knowledge_sources=[
content_source
], # Enable knowledge by adding the sources here. You can also add more sources to the sources list.
)
result = crew.kickoff(
inputs={
"question": "What is the reward hacking paper about? Be sure to provide sources."
}
)
```
## Knowledge Configuration ## Knowledge Configuration
### Chunking Configuration ### Chunking Configuration

View File

@@ -51,6 +51,9 @@ openpyxl = [
"openpyxl>=3.1.5", "openpyxl>=3.1.5",
] ]
mem0 = ["mem0ai>=0.1.29"] mem0 = ["mem0ai>=0.1.29"]
docling = [
"docling>=2.12.0",
]
[tool.uv] [tool.uv]
dev-dependencies = [ dev-dependencies = [

View File

@@ -80,10 +80,27 @@ def listen(condition):
return decorator return decorator
def router(method): def router(condition):
def decorator(func): def decorator(func):
func.__is_router__ = True func.__is_router__ = True
func.__router_for__ = method.__name__ # Handle conditions like listen/start
if isinstance(condition, str):
func.__trigger_methods__ = [condition]
func.__condition_type__ = "OR"
elif (
isinstance(condition, dict)
and "type" in condition
and "methods" in condition
):
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
elif callable(condition) and hasattr(condition, "__name__"):
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
else:
raise ValueError(
"Condition must be a method, string, or a result of or_() or and_()"
)
return func return func
return decorator return decorator
@@ -123,8 +140,8 @@ class FlowMeta(type):
start_methods = [] start_methods = []
listeners = {} listeners = {}
routers = {}
router_paths = {} router_paths = {}
routers = set()
for attr_name, attr_value in dct.items(): for attr_name, attr_value in dct.items():
if hasattr(attr_value, "__is_start_method__"): if hasattr(attr_value, "__is_start_method__"):
@@ -137,18 +154,11 @@ class FlowMeta(type):
methods = attr_value.__trigger_methods__ methods = attr_value.__trigger_methods__
condition_type = getattr(attr_value, "__condition_type__", "OR") condition_type = getattr(attr_value, "__condition_type__", "OR")
listeners[attr_name] = (condition_type, methods) listeners[attr_name] = (condition_type, methods)
if hasattr(attr_value, "__is_router__") and attr_value.__is_router__:
elif hasattr(attr_value, "__is_router__"): routers.add(attr_name)
routers[attr_value.__router_for__] = attr_name possible_returns = get_possible_return_constants(attr_value)
possible_returns = get_possible_return_constants(attr_value) if possible_returns:
if possible_returns: router_paths[attr_name] = possible_returns
router_paths[attr_name] = possible_returns
# Register router as a listener to its triggering method
trigger_method_name = attr_value.__router_for__
methods = [trigger_method_name]
condition_type = "OR"
listeners[attr_name] = (condition_type, methods)
setattr(cls, "_start_methods", start_methods) setattr(cls, "_start_methods", start_methods)
setattr(cls, "_listeners", listeners) setattr(cls, "_listeners", listeners)
@@ -163,7 +173,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
_start_methods: List[str] = [] _start_methods: List[str] = []
_listeners: Dict[str, tuple[str, List[str]]] = {} _listeners: Dict[str, tuple[str, List[str]]] = {}
_routers: Dict[str, str] = {} _routers: Set[str] = set()
_router_paths: Dict[str, List[str]] = {} _router_paths: Dict[str, List[str]] = {}
initial_state: Union[Type[T], T, None] = None initial_state: Union[Type[T], T, None] = None
event_emitter = Signal("event_emitter") event_emitter = Signal("event_emitter")
@@ -210,20 +220,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
return self._method_outputs return self._method_outputs
def _initialize_state(self, inputs: Dict[str, Any]) -> None: def _initialize_state(self, inputs: Dict[str, Any]) -> None:
"""
Initializes or updates the state with the provided inputs.
Args:
inputs: Dictionary of inputs to initialize or update the state.
Raises:
ValueError: If inputs do not match the structured state model.
TypeError: If state is neither a BaseModel instance nor a dictionary.
"""
if isinstance(self._state, BaseModel): if isinstance(self._state, BaseModel):
# Structured state management # Structured state
try: try:
# Define a function to create the dynamic class
def create_model_with_extra_forbid( def create_model_with_extra_forbid(
base_model: Type[BaseModel], base_model: Type[BaseModel],
) -> Type[BaseModel]: ) -> Type[BaseModel]:
@@ -233,34 +233,20 @@ class Flow(Generic[T], metaclass=FlowMeta):
return ModelWithExtraForbid return ModelWithExtraForbid
# Create the dynamic class
ModelWithExtraForbid = create_model_with_extra_forbid( ModelWithExtraForbid = create_model_with_extra_forbid(
self._state.__class__ self._state.__class__
) )
# Create a new instance using the combined state and inputs
self._state = cast( self._state = cast(
T, ModelWithExtraForbid(**{**self._state.model_dump(), **inputs}) T, ModelWithExtraForbid(**{**self._state.model_dump(), **inputs})
) )
except ValidationError as e: except ValidationError as e:
raise ValueError(f"Invalid inputs for structured state: {e}") from e raise ValueError(f"Invalid inputs for structured state: {e}") from e
elif isinstance(self._state, dict): elif isinstance(self._state, dict):
# Unstructured state management
self._state.update(inputs) self._state.update(inputs)
else: else:
raise TypeError("State must be a BaseModel instance or a dictionary.") raise TypeError("State must be a BaseModel instance or a dictionary.")
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any: def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""
Starts the execution of the flow synchronously.
Args:
inputs: Optional dictionary of inputs to initialize or update the state.
Returns:
The final output from the flow execution.
"""
self.event_emitter.send( self.event_emitter.send(
self, self,
event=FlowStartedEvent( event=FlowStartedEvent(
@@ -274,15 +260,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
return asyncio.run(self.kickoff_async()) return asyncio.run(self.kickoff_async())
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any: async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""
Starts the execution of the flow asynchronously.
Args:
inputs: Optional dictionary of inputs to initialize or update the state.
Returns:
The final output from the flow execution.
"""
if not self._start_methods: if not self._start_methods:
raise ValueError("No start method defined") raise ValueError("No start method defined")
@@ -290,16 +267,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
self.__class__.__name__, list(self._methods.keys()) self.__class__.__name__, list(self._methods.keys())
) )
# Create tasks for all start methods
tasks = [ tasks = [
self._execute_start_method(start_method) self._execute_start_method(start_method)
for start_method in self._start_methods for start_method in self._start_methods
] ]
# Run all start methods concurrently
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
# Determine the final output (from the last executed method)
final_output = self._method_outputs[-1] if self._method_outputs else None final_output = self._method_outputs[-1] if self._method_outputs else None
self.event_emitter.send( self.event_emitter.send(
@@ -310,7 +283,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
result=final_output, result=final_output,
), ),
) )
return final_output return final_output
async def _execute_start_method(self, start_method_name: str) -> None: async def _execute_start_method(self, start_method_name: str) -> None:
@@ -327,49 +299,68 @@ class Flow(Generic[T], metaclass=FlowMeta):
if asyncio.iscoroutinefunction(method) if asyncio.iscoroutinefunction(method)
else method(*args, **kwargs) else method(*args, **kwargs)
) )
self._method_outputs.append(result) # Store the output self._method_outputs.append(result)
# Track method execution counts
self._method_execution_counts[method_name] = ( self._method_execution_counts[method_name] = (
self._method_execution_counts.get(method_name, 0) + 1 self._method_execution_counts.get(method_name, 0) + 1
) )
return result return result
async def _execute_listeners(self, trigger_method: str, result: Any) -> None: async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
listener_tasks = [] # First, handle routers repeatedly until no router triggers anymore
while True:
if trigger_method in self._routers: routers_triggered = self._find_triggered_methods(
router_method = self._methods[self._routers[trigger_method]] trigger_method, router_only=True
path = await self._execute_method(
self._routers[trigger_method], router_method
) )
trigger_method = path if not routers_triggered:
break
for router_name in routers_triggered:
await self._execute_single_listener(router_name, result)
# After executing router, the router's result is the path
# The last router executed sets the trigger_method
# The router result is the last element in self._method_outputs
trigger_method = self._method_outputs[-1]
# Now that no more routers are triggered by current trigger_method,
# execute normal listeners
listeners_triggered = self._find_triggered_methods(
trigger_method, router_only=False
)
if listeners_triggered:
tasks = [
self._execute_single_listener(listener_name, result)
for listener_name in listeners_triggered
]
await asyncio.gather(*tasks)
def _find_triggered_methods(
self, trigger_method: str, router_only: bool
) -> List[str]:
triggered = []
for listener_name, (condition_type, methods) in self._listeners.items(): for listener_name, (condition_type, methods) in self._listeners.items():
is_router = listener_name in self._routers
if router_only != is_router:
continue
if condition_type == "OR": if condition_type == "OR":
# If the trigger_method matches any in methods, run this
if trigger_method in methods: if trigger_method in methods:
# Schedule the listener without preventing re-execution triggered.append(listener_name)
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
elif condition_type == "AND": elif condition_type == "AND":
# Initialize pending methods for this listener if not already done # Initialize pending methods for this listener if not already done
if listener_name not in self._pending_and_listeners: if listener_name not in self._pending_and_listeners:
self._pending_and_listeners[listener_name] = set(methods) self._pending_and_listeners[listener_name] = set(methods)
# Remove the trigger method from pending methods # Remove the trigger method from pending methods
self._pending_and_listeners[listener_name].discard(trigger_method) if trigger_method in self._pending_and_listeners[listener_name]:
self._pending_and_listeners[listener_name].discard(trigger_method)
if not self._pending_and_listeners[listener_name]: if not self._pending_and_listeners[listener_name]:
# All required methods have been executed # All required methods have been executed
listener_tasks.append( triggered.append(listener_name)
self._execute_single_listener(listener_name, result)
)
# Reset pending methods for this listener # Reset pending methods for this listener
self._pending_and_listeners.pop(listener_name, None) self._pending_and_listeners.pop(listener_name, None)
# Run all listener tasks concurrently and wait for them to complete return triggered
if listener_tasks:
await asyncio.gather(*listener_tasks)
async def _execute_single_listener(self, listener_name: str, result: Any) -> None: async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
try: try:
@@ -386,17 +377,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
sig = inspect.signature(method) sig = inspect.signature(method)
params = list(sig.parameters.values()) params = list(sig.parameters.values())
# Exclude 'self' parameter
method_params = [p for p in params if p.name != "self"] method_params = [p for p in params if p.name != "self"]
if method_params: if method_params:
# If listener expects parameters, pass the result
listener_result = await self._execute_method( listener_result = await self._execute_method(
listener_name, method, result listener_name, method, result
) )
else: else:
# If listener does not expect parameters, call without arguments
listener_result = await self._execute_method(listener_name, method) listener_result = await self._execute_method(listener_name, method)
self.event_emitter.send( self.event_emitter.send(
@@ -408,8 +395,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
), ),
) )
# Execute listeners of this listener # Execute listeners (and possibly routers) of this listener
await self._execute_listeners(listener_name, listener_result) await self._execute_listeners(listener_name, listener_result)
except Exception as e: except Exception as e:
print( print(
f"[Flow._execute_single_listener] Error in method {listener_name}: {e}" f"[Flow._execute_single_listener] Error in method {listener_name}: {e}"
@@ -422,5 +410,4 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._telemetry.flow_plotting_span( self._telemetry.flow_plotting_span(
self.__class__.__name__, list(self._methods.keys()) self.__class__.__name__, list(self._methods.keys())
) )
plot_flow(self, filename) plot_flow(self, filename)

View File

@@ -31,16 +31,50 @@ def get_possible_return_constants(function):
print(f"Source code:\n{source}") print(f"Source code:\n{source}")
return None return None
return_values = [] return_values = set()
dict_definitions = {}
class DictionaryAssignmentVisitor(ast.NodeVisitor):
def visit_Assign(self, node):
# Check if this assignment is assigning a dictionary literal to a variable
if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
target = node.targets[0]
if isinstance(target, ast.Name):
var_name = target.id
dict_values = []
# Extract string values from the dictionary
for val in node.value.values:
if isinstance(val, ast.Constant) and isinstance(val.value, str):
dict_values.append(val.value)
# If non-string, skip or just ignore
if dict_values:
dict_definitions[var_name] = dict_values
self.generic_visit(node)
class ReturnVisitor(ast.NodeVisitor): class ReturnVisitor(ast.NodeVisitor):
def visit_Return(self, node): def visit_Return(self, node):
# Check if the return value is a constant (Python 3.8+) # Direct string return
if isinstance(node.value, ast.Constant): if isinstance(node.value, ast.Constant) and isinstance(
return_values.append(node.value.value) node.value.value, str
):
return_values.add(node.value.value)
# Dictionary-based return, like return paths[result]
elif isinstance(node.value, ast.Subscript):
# Check if we're subscripting a known dictionary variable
if isinstance(node.value.value, ast.Name):
var_name = node.value.value.id
if var_name in dict_definitions:
# Add all possible dictionary values
for v in dict_definitions[var_name]:
return_values.add(v)
self.generic_visit(node)
# First pass: identify dictionary assignments
DictionaryAssignmentVisitor().visit(code_ast)
# Second pass: identify returns
ReturnVisitor().visit(code_ast) ReturnVisitor().visit(code_ast)
return return_values
return list(return_values) if return_values else None
def calculate_node_levels(flow): def calculate_node_levels(flow):
@@ -61,10 +95,7 @@ def calculate_node_levels(flow):
current_level = levels[current] current_level = levels[current]
visited.add(current) visited.add(current)
for listener_name, ( for listener_name, (condition_type, trigger_methods) in flow._listeners.items():
condition_type,
trigger_methods,
) in flow._listeners.items():
if condition_type == "OR": if condition_type == "OR":
if current in trigger_methods: if current in trigger_methods:
if ( if (
@@ -89,7 +120,7 @@ def calculate_node_levels(flow):
queue.append(listener_name) queue.append(listener_name)
# Handle router connections # Handle router connections
if current in flow._routers.values(): if current in flow._routers:
router_method_name = current router_method_name = current
paths = flow._router_paths.get(router_method_name, []) paths = flow._router_paths.get(router_method_name, [])
for path in paths: for path in paths:
@@ -105,6 +136,7 @@ def calculate_node_levels(flow):
levels[listener_name] = current_level + 1 levels[listener_name] = current_level + 1
if listener_name not in visited: if listener_name not in visited:
queue.append(listener_name) queue.append(listener_name)
return levels return levels
@@ -142,7 +174,7 @@ def dfs_ancestors(node, ancestors, visited, flow):
dfs_ancestors(listener_name, ancestors, visited, flow) dfs_ancestors(listener_name, ancestors, visited, flow)
# Handle router methods separately # Handle router methods separately
if node in flow._routers.values(): if node in flow._routers:
router_method_name = node router_method_name = node
paths = flow._router_paths.get(router_method_name, []) paths = flow._router_paths.get(router_method_name, [])
for path in paths: for path in paths:

View File

@@ -94,12 +94,14 @@ def add_edges(net, flow, node_positions, colors):
ancestors = build_ancestor_dict(flow) ancestors = build_ancestor_dict(flow)
parent_children = build_parent_children_dict(flow) parent_children = build_parent_children_dict(flow)
# Edges for normal listeners
for method_name in flow._listeners: for method_name in flow._listeners:
condition_type, trigger_methods = flow._listeners[method_name] condition_type, trigger_methods = flow._listeners[method_name]
is_and_condition = condition_type == "AND" is_and_condition = condition_type == "AND"
for trigger in trigger_methods: for trigger in trigger_methods:
if trigger in flow._methods or trigger in flow._routers.values(): # Check if nodes exist before adding edges
if trigger in node_positions and method_name in node_positions:
is_router_edge = any( is_router_edge = any(
trigger in paths for paths in flow._router_paths.values() trigger in paths for paths in flow._router_paths.values()
) )
@@ -135,7 +137,22 @@ def add_edges(net, flow, node_positions, colors):
} }
net.add_edge(trigger, method_name, **edge_style) net.add_edge(trigger, method_name, **edge_style)
else:
# Nodes not found in node_positions. Check if it's a known router outcome and a known method.
is_router_edge = any(
trigger in paths for paths in flow._router_paths.values()
)
# Check if method_name is a known method
method_known = method_name in flow._methods
# If it's a known router edge and the method is known, don't warn.
# This means the path is legitimate, just not reflected as nodes here.
if not (is_router_edge and method_known):
print(
f"Warning: No node found for '{trigger}' or '{method_name}'. Skipping edge."
)
# Edges for router return paths
for router_method_name, paths in flow._router_paths.items(): for router_method_name, paths in flow._router_paths.items():
for path in paths: for path in paths:
for listener_name, ( for listener_name, (
@@ -143,36 +160,49 @@ def add_edges(net, flow, node_positions, colors):
trigger_methods, trigger_methods,
) in flow._listeners.items(): ) in flow._listeners.items():
if path in trigger_methods: if path in trigger_methods:
is_cycle_edge = is_ancestor(trigger, method_name, ancestors) if (
parent_has_multiple_children = ( router_method_name in node_positions
len(parent_children.get(router_method_name, [])) > 1 and listener_name in node_positions
) ):
needs_curvature = is_cycle_edge or parent_has_multiple_children is_cycle_edge = is_ancestor(
router_method_name, listener_name, ancestors
)
parent_has_multiple_children = (
len(parent_children.get(router_method_name, [])) > 1
)
needs_curvature = is_cycle_edge or parent_has_multiple_children
if needs_curvature: if needs_curvature:
source_pos = node_positions.get(router_method_name) source_pos = node_positions.get(router_method_name)
target_pos = node_positions.get(listener_name) target_pos = node_positions.get(listener_name)
if source_pos and target_pos: if source_pos and target_pos:
dx = target_pos[0] - source_pos[0] dx = target_pos[0] - source_pos[0]
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW" smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
index = get_child_index( index = get_child_index(
router_method_name, listener_name, parent_children router_method_name, listener_name, parent_children
) )
edge_smooth = { edge_smooth = {
"type": smooth_type, "type": smooth_type,
"roundness": 0.2 + (0.1 * index), "roundness": 0.2 + (0.1 * index),
} }
else:
edge_smooth = {"type": "cubicBezier"}
else: else:
edge_smooth = {"type": "cubicBezier"} edge_smooth = False
else:
edge_smooth = False
edge_style = { edge_style = {
"color": colors["router_edge"], "color": colors["router_edge"],
"width": 2, "width": 2,
"arrows": "to", "arrows": "to",
"dashes": True, "dashes": True,
"smooth": edge_smooth, "smooth": edge_smooth,
} }
net.add_edge(router_method_name, listener_name, **edge_style) net.add_edge(router_method_name, listener_name, **edge_style)
else:
# Same check here: known router edge and known method?
method_known = listener_name in flow._methods
if not method_known:
print(
f"Warning: No node found for '{router_method_name}' or '{listener_name}'. Skipping edge."
)

View File

@@ -14,13 +14,13 @@ class Knowledge(BaseModel):
Knowledge is a collection of sources and setup for the vector store to save and query relevant context. Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
Args: Args:
sources: List[BaseKnowledgeSource] = Field(default_factory=list) sources: List[BaseKnowledgeSource] = Field(default_factory=list)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage) storage: Optional[KnowledgeStorage] = Field(default=None)
embedder_config: Optional[Dict[str, Any]] = None embedder_config: Optional[Dict[str, Any]] = None
""" """
sources: List[BaseKnowledgeSource] = Field(default_factory=list) sources: List[BaseKnowledgeSource] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage) storage: Optional[KnowledgeStorage] = Field(default=None)
embedder_config: Optional[Dict[str, Any]] = None embedder_config: Optional[Dict[str, Any]] = None
collection_name: Optional[str] = None collection_name: Optional[str] = None

View File

@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Dict, List, Union from typing import Dict, List, Optional, Union
from pydantic import Field from pydantic import Field, field_validator
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
@@ -14,17 +14,28 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
"""Base class for knowledge sources that load content from files.""" """Base class for knowledge sources that load content from files."""
_logger: Logger = Logger(verbose=True) _logger: Logger = Logger(verbose=True)
file_path: Union[Path, List[Path], str, List[str]] = Field( file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
..., description="The path to the file" default=None,
description="[Deprecated] The path to the file. Use file_paths instead.",
)
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
default_factory=list, description="The path to the file"
) )
content: Dict[Path, str] = Field(init=False, default_factory=dict) content: Dict[Path, str] = Field(init=False, default_factory=dict)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage) storage: Optional[KnowledgeStorage] = Field(default=None)
safe_file_paths: List[Path] = Field(default_factory=list) safe_file_paths: List[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(cls, v, values):
"""Validate that at least one of file_path or file_paths is provided."""
if v is None and ("file_path" not in values or values.get("file_path") is None):
raise ValueError("Either file_path or file_paths must be provided")
return v
def model_post_init(self, _): def model_post_init(self, _):
"""Post-initialization method to load content.""" """Post-initialization method to load content."""
self.safe_file_paths = self._process_file_paths() self.safe_file_paths = self._process_file_paths()
self.validate_paths() self.validate_content()
self.content = self.load_content() self.content = self.load_content()
@abstractmethod @abstractmethod
@@ -32,7 +43,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory.""" """Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
pass pass
def validate_paths(self): def validate_content(self):
"""Validate the paths.""" """Validate the paths."""
for path in self.safe_file_paths: for path in self.safe_file_paths:
if not path.exists(): if not path.exists():
@@ -51,7 +62,10 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
def _save_documents(self): def _save_documents(self):
"""Save the documents to the storage.""" """Save the documents to the storage."""
self.storage.save(self.chunks) if self.storage:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")
def convert_to_path(self, path: Union[Path, str]) -> Path: def convert_to_path(self, path: Union[Path, str]) -> Path:
"""Convert a path to a Path object.""" """Convert a path to a Path object."""
@@ -59,13 +73,30 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
def _process_file_paths(self) -> List[Path]: def _process_file_paths(self) -> List[Path]:
"""Convert file_path to a list of Path objects.""" """Convert file_path to a list of Path objects."""
paths = (
[self.file_path] if hasattr(self, "file_path") and self.file_path is not None:
if isinstance(self.file_path, (str, Path)) self._logger.log(
else self.file_path "warning",
"The 'file_path' attribute is deprecated and will be removed in a future version. Please use 'file_paths' instead.",
color="yellow",
)
self.file_paths = self.file_path
if self.file_paths is None:
raise ValueError("Your source must be provided with a file_paths: []")
# Convert single path to list
path_list: List[Union[Path, str]] = (
[self.file_paths]
if isinstance(self.file_paths, (str, Path))
else list(self.file_paths)
if isinstance(self.file_paths, list)
else []
) )
if not isinstance(paths, list): if not path_list:
raise ValueError("file_path must be a Path, str, or a list of these types") raise ValueError(
"file_path/file_paths must be a Path, str, or a list of these types"
)
return [self.convert_to_path(path) for path in paths] return [self.convert_to_path(path) for path in path_list]

View File

@@ -16,12 +16,12 @@ class BaseKnowledgeSource(BaseModel, ABC):
chunk_embeddings: List[np.ndarray] = Field(default_factory=list) chunk_embeddings: List[np.ndarray] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage) storage: Optional[KnowledgeStorage] = Field(default=None)
metadata: Dict[str, Any] = Field(default_factory=dict) # Currently unused metadata: Dict[str, Any] = Field(default_factory=dict) # Currently unused
collection_name: Optional[str] = Field(default=None) collection_name: Optional[str] = Field(default=None)
@abstractmethod @abstractmethod
def load_content(self) -> Dict[Any, str]: def validate_content(self) -> Any:
"""Load and preprocess content from the source.""" """Load and preprocess content from the source."""
pass pass
@@ -46,4 +46,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
Save the documents to the storage. Save the documents to the storage.
This method should be called after the chunks and embeddings are generated. This method should be called after the chunks and embeddings are generated.
""" """
self.storage.save(self.chunks) if self.storage:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")

View File

@@ -0,0 +1,120 @@
from pathlib import Path
from typing import Iterator, List, Optional, Union
from urllib.parse import urlparse
from docling.datamodel.base_models import InputFormat
from docling.document_converter import DocumentConverter
from docling.exceptions import ConversionError
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
from docling_core.types.doc.document import DoclingDocument
from pydantic import Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.utilities.logger import Logger
class CrewDoclingSource(BaseKnowledgeSource):
    """Default Source class for converting documents to markdown or json

    This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
    """

    _logger: Logger = Logger(verbose=True)

    # Deprecated alias kept for backward compatibility; prefer `file_paths`.
    file_path: Optional[List[Union[Path, str]]] = Field(default=None)
    file_paths: List[Union[Path, str]] = Field(default_factory=list)
    chunks: List[str] = Field(default_factory=list)
    # Validated inputs: URLs stay as str, local files become resolved Path objects.
    safe_file_paths: List[Union[Path, str]] = Field(default_factory=list)
    content: List[DoclingDocument] = Field(default_factory=list)
    document_converter: DocumentConverter = Field(
        default_factory=lambda: DocumentConverter(
            allowed_formats=[
                InputFormat.MD,
                InputFormat.ASCIIDOC,
                InputFormat.PDF,
                InputFormat.DOCX,
                InputFormat.HTML,
                InputFormat.IMAGE,
                InputFormat.XLSX,
                InputFormat.PPTX,
            ]
        )
    )

    def model_post_init(self, _) -> None:
        """Validate configured paths and eagerly convert all sources to documents."""
        if self.file_path:
            self._logger.log(
                "warning",
                "The 'file_path' attribute is deprecated and will be removed in a future version. Please use 'file_paths' instead.",
                color="yellow",
            )
            self.file_paths = self.file_path
        self.safe_file_paths = self.validate_content()
        self.content = self._load_content()

    def _load_content(self) -> List[DoclingDocument]:
        """Convert all validated sources, logging failures before re-raising.

        Raises:
            ConversionError: if docling cannot convert a source.
        """
        try:
            return self._convert_source_to_docling_documents()
        except ConversionError as e:
            self._logger.log(
                "error",
                f"Error loading content: {e}. Supported formats: {self.document_converter.allowed_formats}",
                "red",
            )
            # Bare `raise` preserves the original traceback exactly.
            raise
        except Exception as e:
            self._logger.log("error", f"Error loading content: {e}")
            raise

    def add(self) -> None:
        """Chunk every converted document and persist the chunks to storage."""
        if self.content is None:  # defensive guard; `content` defaults to []
            return
        for doc in self.content:
            # extend() accepts any iterable; no need to materialize a list first
            self.chunks.extend(self._chunk_doc(doc))
        self._save_documents()

    def _convert_source_to_docling_documents(self) -> List[DoclingDocument]:
        """Run docling conversion over all validated paths/URLs."""
        conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
        return [result.document for result in conv_results_iter]

    def _chunk_doc(self, doc: DoclingDocument) -> Iterator[str]:
        """Yield the text of each hierarchical chunk of *doc*."""
        chunker = HierarchicalChunker()
        for chunk in chunker.chunk(doc):
            yield chunk.text

    def validate_content(self) -> List[Union[Path, str]]:
        """Return validated sources: URLs are format-checked, local files must exist.

        Raises:
            ValueError: if a URL string is malformed.
            FileNotFoundError: if a local file is missing from KNOWLEDGE_DIRECTORY.
        """
        processed_paths: List[Union[Path, str]] = []
        for path in self.file_paths:
            if isinstance(path, Path):
                processed_paths.append(path)
            elif path.startswith(("http://", "https://")):
                # _validate_url never raises, so no try/except wrapper is needed
                # (the old one caught its own ValueError and double-wrapped it).
                if not self._validate_url(path):
                    raise ValueError(f"Invalid URL format: {path}")
                processed_paths.append(path)
            else:
                local_path = Path(KNOWLEDGE_DIRECTORY) / path
                if not local_path.exists():
                    raise FileNotFoundError(f"File not found: {local_path}")
                processed_paths.append(local_path)
        return processed_paths

    def _validate_url(self, url: str) -> bool:
        """Best-effort syntactic URL check; never raises."""
        try:
            result = urlparse(url)
            return all(
                [
                    result.scheme in ("http", "https"),
                    result.netloc,
                    len(result.netloc.split(".")) >= 2,  # Ensure domain has TLD
                ]
            )
        except Exception:
            return False

View File

@@ -13,9 +13,9 @@ class StringKnowledgeSource(BaseKnowledgeSource):
def model_post_init(self, _): def model_post_init(self, _):
"""Post-initialization method to validate content.""" """Post-initialization method to validate content."""
self.load_content() self.validate_content()
def load_content(self): def validate_content(self):
"""Validate string content.""" """Validate string content."""
if not isinstance(self.content, str): if not isinstance(self.content, str):
raise ValueError("StringKnowledgeSource only accepts string content") raise ValueError("StringKnowledgeSource only accepts string content")

View File

@@ -263,3 +263,62 @@ def test_flow_with_custom_state():
flow = StateFlow() flow = StateFlow()
flow.kickoff() flow.kickoff()
assert flow.counter == 2 assert flow.counter == 2
def test_router_with_multiple_conditions():
    """Verify OR- and AND-condition routers.

    ``router_or`` must fire once either start step finishes, while
    ``router_and`` must fire only after both of its listened steps complete.
    """
    events = []

    class ComplexRouterFlow(Flow):
        @start()
        def step_a(self):
            events.append("step_a")

        @start()
        def step_b(self):
            events.append("step_b")

        @router(or_("step_a", "step_b"))
        def router_or(self):
            events.append("router_or")
            return "next_step_or"

        @listen("next_step_or")
        def handle_next_step_or_event(self):
            events.append("handle_next_step_or_event")

        @listen(handle_next_step_or_event)
        def branch_2_step(self):
            events.append("branch_2_step")

        @router(and_(handle_next_step_or_event, branch_2_step))
        def router_and(self):
            events.append("router_and")
            return "final_step"

        @listen("final_step")
        def log_final_step(self):
            events.append("log_final_step")

    ComplexRouterFlow().kickoff()

    for name in (
        "step_a",
        "step_b",
        "router_or",
        "handle_next_step_or_event",
        "branch_2_step",
        "router_and",
        "log_final_step",
    ):
        assert name in events

    pos = events.index
    # The AND router triggers only after both relevant steps have run:
    assert pos("router_and") > pos("handle_next_step_or_event")
    assert pos("router_and") > pos("branch_2_step")
    # ...and the final step runs after the AND router.
    assert pos("log_final_step") > pos("router_and")

View File

@@ -1,10 +1,12 @@
"""Test Knowledge creation and querying functionality.""" """Test Knowledge creation and querying functionality."""
from pathlib import Path from pathlib import Path
from typing import List, Union
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from crewai.knowledge.source.crew_docling_source import CrewDoclingSource
from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource
from crewai.knowledge.source.excel_knowledge_source import ExcelKnowledgeSource from crewai.knowledge.source.excel_knowledge_source import ExcelKnowledgeSource
from crewai.knowledge.source.json_knowledge_source import JSONKnowledgeSource from crewai.knowledge.source.json_knowledge_source import JSONKnowledgeSource
@@ -200,7 +202,7 @@ def test_single_short_file(mock_vector_db, tmpdir):
f.write(content) f.write(content)
file_source = TextFileKnowledgeSource( file_source = TextFileKnowledgeSource(
file_path=file_path, metadata={"preference": "personal"} file_paths=[file_path], metadata={"preference": "personal"}
) )
mock_vector_db.sources = [file_source] mock_vector_db.sources = [file_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}] mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
@@ -242,7 +244,7 @@ def test_single_2k_character_file(mock_vector_db, tmpdir):
f.write(content) f.write(content)
file_source = TextFileKnowledgeSource( file_source = TextFileKnowledgeSource(
file_path=file_path, metadata={"preference": "personal"} file_paths=[file_path], metadata={"preference": "personal"}
) )
mock_vector_db.sources = [file_source] mock_vector_db.sources = [file_source]
mock_vector_db.query.return_value = [{"context": content, "score": 0.9}] mock_vector_db.query.return_value = [{"context": content, "score": 0.9}]
@@ -279,7 +281,7 @@ def test_multiple_short_files(mock_vector_db, tmpdir):
file_paths.append((file_path, item["metadata"])) file_paths.append((file_path, item["metadata"]))
file_sources = [ file_sources = [
TextFileKnowledgeSource(file_path=path, metadata=metadata) TextFileKnowledgeSource(file_paths=[path], metadata=metadata)
for path, metadata in file_paths for path, metadata in file_paths
] ]
mock_vector_db.sources = file_sources mock_vector_db.sources = file_sources
@@ -352,7 +354,7 @@ def test_multiple_2k_character_files(mock_vector_db, tmpdir):
file_paths.append(file_path) file_paths.append(file_path)
file_sources = [ file_sources = [
TextFileKnowledgeSource(file_path=path, metadata={"preference": "personal"}) TextFileKnowledgeSource(file_paths=[path], metadata={"preference": "personal"})
for path in file_paths for path in file_paths
] ]
mock_vector_db.sources = file_sources mock_vector_db.sources = file_sources
@@ -399,7 +401,7 @@ def test_hybrid_string_and_files(mock_vector_db, tmpdir):
file_paths.append(file_path) file_paths.append(file_path)
file_sources = [ file_sources = [
TextFileKnowledgeSource(file_path=path, metadata={"preference": "personal"}) TextFileKnowledgeSource(file_paths=[path], metadata={"preference": "personal"})
for path in file_paths for path in file_paths
] ]
@@ -424,7 +426,7 @@ def test_pdf_knowledge_source(mock_vector_db):
# Create a PDFKnowledgeSource # Create a PDFKnowledgeSource
pdf_source = PDFKnowledgeSource( pdf_source = PDFKnowledgeSource(
file_path=pdf_path, metadata={"preference": "personal"} file_paths=[pdf_path], metadata={"preference": "personal"}
) )
mock_vector_db.sources = [pdf_source] mock_vector_db.sources = [pdf_source]
mock_vector_db.query.return_value = [ mock_vector_db.query.return_value = [
@@ -461,7 +463,7 @@ def test_csv_knowledge_source(mock_vector_db, tmpdir):
# Create a CSVKnowledgeSource # Create a CSVKnowledgeSource
csv_source = CSVKnowledgeSource( csv_source = CSVKnowledgeSource(
file_path=csv_path, metadata={"preference": "personal"} file_paths=[csv_path], metadata={"preference": "personal"}
) )
mock_vector_db.sources = [csv_source] mock_vector_db.sources = [csv_source]
mock_vector_db.query.return_value = [ mock_vector_db.query.return_value = [
@@ -496,7 +498,7 @@ def test_json_knowledge_source(mock_vector_db, tmpdir):
# Create a JSONKnowledgeSource # Create a JSONKnowledgeSource
json_source = JSONKnowledgeSource( json_source = JSONKnowledgeSource(
file_path=json_path, metadata={"preference": "personal"} file_paths=[json_path], metadata={"preference": "personal"}
) )
mock_vector_db.sources = [json_source] mock_vector_db.sources = [json_source]
mock_vector_db.query.return_value = [ mock_vector_db.query.return_value = [
@@ -529,7 +531,7 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
# Create an ExcelKnowledgeSource # Create an ExcelKnowledgeSource
excel_source = ExcelKnowledgeSource( excel_source = ExcelKnowledgeSource(
file_path=excel_path, metadata={"preference": "personal"} file_paths=[excel_path], metadata={"preference": "personal"}
) )
mock_vector_db.sources = [excel_source] mock_vector_db.sources = [excel_source]
mock_vector_db.query.return_value = [ mock_vector_db.query.return_value = [
@@ -543,3 +545,42 @@ def test_excel_knowledge_source(mock_vector_db, tmpdir):
# Assert that the correct information is retrieved # Assert that the correct information is retrieved
assert any("30" in result["context"] for result in results) assert any("30" in result["context"] for result in results)
mock_vector_db.query.assert_called_once() mock_vector_db.query.assert_called_once()
def test_docling_source(mock_vector_db):
    """Query a mocked vector DB backed by a web-based CrewDoclingSource."""
    # NOTE(review): constructing the source fetches the URL — assumes network access.
    web_source = CrewDoclingSource(
        file_paths=[
            "https://lilianweng.github.io/posts/2024-11-28-reward-hacking/",
        ],
    )
    mock_vector_db.sources = [web_source]
    mock_vector_db.query.return_value = [
        {
            "context": "Reward hacking is a technique used to improve the performance of reinforcement learning agents.",
            "score": 0.9,
        }
    ]

    # Perform a query and confirm the mocked context is surfaced.
    results = mock_vector_db.query("What is reward hacking?")
    matches = [r for r in results if "reward hacking" in r["context"].lower()]
    assert matches
    mock_vector_db.query.assert_called_once()
def test_multiple_docling_sources():
    """A source built from several URLs keeps them all and loads content eagerly."""
    urls: List[Union[Path, str]] = [
        "https://lilianweng.github.io/posts/2024-11-28-reward-hacking/",
        "https://lilianweng.github.io/posts/2024-07-07-hallucination/",
    ]
    source = CrewDoclingSource(file_paths=urls)

    assert source.file_paths == urls
    assert source.content is not None
def test_docling_source_with_local_file():
    """A local PDF next to this test file is accepted and converted."""
    pdf_path = Path(__file__).parent / "crewai_quickstart.pdf"
    source = CrewDoclingSource(file_paths=[pdf_path])

    assert source.file_paths == [pdf_path]
    assert source.content is not None

911
uv.lock generated

File diff suppressed because it is too large Load Diff