mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 08:12:39 +00:00
Adding Autocomplete to OSS (#1198)
* Cleaned up model_config * Fix pydantic issues * 99% done with autocomplete * fixed test issues * Fix type checking issues
This commit is contained in:
committed by
GitHub
parent
3451b6fc7a
commit
bf7372fefa
@@ -1,32 +1,26 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, Generic, Tuple, TypeVar
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
T = TypeVar("T", bound=Dict[str, Any])
|
||||
U = TypeVar("U")
|
||||
|
||||
class Route(BaseModel):
|
||||
condition: Callable[[Dict[str, Any]], bool]
|
||||
pipeline: Any
|
||||
|
||||
|
||||
class Route(Generic[T, U]):
|
||||
condition: Callable[[T], bool]
|
||||
pipeline: U
|
||||
|
||||
def __init__(self, condition: Callable[[T], bool], pipeline: U):
|
||||
self.condition = condition
|
||||
self.pipeline = pipeline
|
||||
|
||||
|
||||
class Router(BaseModel, Generic[T, U]):
|
||||
routes: Dict[str, Route[T, U]] = Field(
|
||||
class Router(BaseModel):
|
||||
routes: Dict[str, Route] = Field(
|
||||
default_factory=dict,
|
||||
description="Dictionary of route names to (condition, pipeline) tuples",
|
||||
)
|
||||
default: U = Field(..., description="Default pipeline if no conditions are met")
|
||||
default: Any = Field(..., description="Default pipeline if no conditions are met")
|
||||
_route_types: Dict[str, type] = PrivateAttr(default_factory=dict)
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, routes: Dict[str, Route[T, U]], default: U, **data):
|
||||
def __init__(self, routes: Dict[str, Route], default: Any, **data):
|
||||
super().__init__(routes=routes, default=default, **data)
|
||||
self._check_copyable(default)
|
||||
for name, route in routes.items():
|
||||
@@ -34,16 +28,16 @@ class Router(BaseModel, Generic[T, U]):
|
||||
self._route_types[name] = type(route.pipeline)
|
||||
|
||||
@staticmethod
|
||||
def _check_copyable(obj):
|
||||
def _check_copyable(obj: Any) -> None:
|
||||
if not hasattr(obj, "copy") or not callable(getattr(obj, "copy")):
|
||||
raise ValueError(f"Object of type {type(obj)} must have a 'copy' method")
|
||||
|
||||
def add_route(
|
||||
self,
|
||||
name: str,
|
||||
condition: Callable[[T], bool],
|
||||
pipeline: U,
|
||||
) -> "Router[T, U]":
|
||||
condition: Callable[[Dict[str, Any]], bool],
|
||||
pipeline: Any,
|
||||
) -> "Router":
|
||||
"""
|
||||
Add a named route with its condition and corresponding pipeline to the router.
|
||||
|
||||
@@ -60,7 +54,7 @@ class Router(BaseModel, Generic[T, U]):
|
||||
self._route_types[name] = type(pipeline)
|
||||
return self
|
||||
|
||||
def route(self, input_data: T) -> Tuple[U, str]:
|
||||
def route(self, input_data: Dict[str, Any]) -> Tuple[Any, str]:
|
||||
"""
|
||||
Evaluate the input against the conditions and return the appropriate pipeline.
|
||||
|
||||
@@ -76,15 +70,15 @@ class Router(BaseModel, Generic[T, U]):
|
||||
|
||||
return self.default, "default"
|
||||
|
||||
def copy(self) -> "Router[T, U]":
|
||||
def copy(self) -> "Router":
|
||||
"""Create a deep copy of the Router."""
|
||||
new_routes = {
|
||||
name: Route(
|
||||
condition=deepcopy(route.condition),
|
||||
pipeline=route.pipeline.copy(), # type: ignore
|
||||
pipeline=route.pipeline.copy(),
|
||||
)
|
||||
for name, route in self.routes.items()
|
||||
}
|
||||
new_default = self.default.copy() # type: ignore
|
||||
new_default = self.default.copy()
|
||||
|
||||
return Router(routes=new_routes, default=new_default)
|
||||
|
||||
Reference in New Issue
Block a user