Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-09 08:08:32 +00:00

feat: improve data training for models up to 7B parameters (#3085)

* feat: improve data training for models up to 7B parameters.
* docs: add training considerations for small models to the documentation.
@@ -6,10 +6,10 @@ icon: dumbbell

## Overview

The training feature in CrewAI allows you to train your AI agents using the command-line interface (CLI).
By running the command `crewai train -n <n_iterations>`, you can specify the number of iterations for the training process.

During training, CrewAI utilizes techniques to optimize the performance of your agents along with human feedback.
This helps the agents improve their understanding, decision-making, and problem-solving abilities.

### Training Your Crew Using the CLI
@@ -42,8 +42,8 @@ filename = "your_model.pkl"

try:
    YourCrewName_Crew().crew().train(
        n_iterations=n_iterations,
        inputs=inputs,
        filename=filename
    )
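For context, this hunk sits inside the docs' programmatic-training example. A self-contained sketch of that example follows; the placeholder values for `n_iterations` and `inputs`, the import path, and the exception handling are assumptions for illustration, not the exact docs text.

```python
from your_project.crew import YourCrewName_Crew  # hypothetical import path for the example crew

n_iterations = 2                    # assumed placeholder value
inputs = {"topic": "Your topic"}    # assumed placeholder inputs
filename = "your_model.pkl"         # from the hunk header above

try:
    # Train the crew programmatically, mirroring `crewai train -n <n_iterations>`
    YourCrewName_Crew().crew().train(
        n_iterations=n_iterations,
        inputs=inputs,
        filename=filename
    )
except Exception as e:
    raise Exception(f"An error occurred while training the crew: {e}")
```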
@@ -64,4 +64,68 @@ Once the training is complete, your agents will be equipped with enhanced capabi

Remember to regularly update and retrain your agents to ensure they stay up-to-date with the latest information and advancements in the field.

Happy training with CrewAI! 🚀

## Small Language Model Considerations

<Warning>
When using smaller language models (≤7B parameters) for training data evaluation, be aware that they may face challenges with generating structured outputs and following complex instructions.
</Warning>

### Limitations of Small Models in Training Evaluation

<CardGroup cols={2}>
  <Card title="JSON Output Accuracy" icon="triangle-exclamation">
    Smaller models often struggle with producing valid JSON responses needed for structured training evaluations, leading to parsing errors and incomplete data.
  </Card>
  <Card title="Evaluation Quality" icon="chart-line">
    Models under 7B parameters may provide less nuanced evaluations with limited reasoning depth compared to larger models.
  </Card>
  <Card title="Instruction Following" icon="list-check">
    Complex training evaluation criteria may not be fully followed or considered by smaller models.
  </Card>
  <Card title="Consistency" icon="rotate">
    Evaluations across multiple training iterations may lack consistency with smaller models.
  </Card>
</CardGroup>

### Recommendations for Training

<Tabs>
  <Tab title="Best Practice">
    For optimal training quality and reliable evaluations, we strongly recommend using models of at least 7B parameters, and ideally larger:

    ```python
    from crewai import Agent, Crew, Task, LLM

    # Recommended minimum for training evaluation
    llm = LLM(model="mistral/open-mistral-7b")

    # Better options for reliable training evaluation
    llm = LLM(model="anthropic/claude-3-sonnet-20240229-v1:0")
    llm = LLM(model="gpt-4o")

    # Use this LLM with your agents
    agent = Agent(
        role="Training Evaluator",
        goal="Provide accurate training feedback",
        llm=llm
    )
    ```

    <Tip>
    More powerful models provide higher quality feedback with better reasoning, leading to more effective training iterations.
    </Tip>
  </Tab>

  <Tab title="Small Model Usage">
    If you must use smaller models for training evaluation, be aware of these constraints:

    ```python
    # Using a smaller model (expect some limitations)
    llm = LLM(model="huggingface/microsoft/Phi-3-mini-4k-instruct")
    ```

    <Warning>
    While CrewAI includes optimizations for small models, expect less reliable and less nuanced evaluation results that may require more human intervention during training.
    </Warning>
  </Tab>
</Tabs>
@@ -5,6 +5,7 @@ from pydantic import BaseModel, Field
 from crewai.utilities import Converter
 from crewai.utilities.events import TaskEvaluationEvent, crewai_event_bus
 from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
+from crewai.utilities.training_converter import TrainingConverter


 class Entity(BaseModel):
@@ -133,7 +134,7 @@ class TaskEvaluator:
         ).get_schema()
         instructions = f"{instructions}\n\nThe json should have the following structure, with the following keys:\n{model_schema}"

-        converter = Converter(
+        converter = TrainingConverter(
             llm=self.llm,
             text=evaluation_query,
             model=TrainingTaskEvaluation,
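The `TrainingTaskEvaluation` model the converter now targets is not part of this diff. Judging from the fields the tests below construct (`suggestions`, `quality`, `final_summary`), its shape is roughly as sketched here; the types follow the test values and the field descriptions are illustrative assumptions.

```python
from typing import List

from pydantic import BaseModel, Field


class TrainingTaskEvaluation(BaseModel):
    # Field names taken from the tests in this PR; descriptions are assumptions.
    suggestions: List[str] = Field(description="Actionable suggestions for improving future outputs")
    quality: float = Field(description="Quality score for the improved output")
    final_summary: str = Field(description="Short summary of the evaluation")
```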
src/crewai/utilities/training_converter.py (new file, 89 lines)
@@ -0,0 +1,89 @@
import json
import re
from typing import Any, get_origin

from pydantic import BaseModel, ValidationError

from crewai.utilities.converter import Converter, ConverterError


class TrainingConverter(Converter):
    """
    A specialized converter for smaller LLMs (up to 7B parameters) that handles validation errors
    by breaking down the model into individual fields and querying the LLM for each field separately.
    """

    def to_pydantic(self, current_attempt=1) -> BaseModel:
        try:
            return super().to_pydantic(current_attempt)
        except ConverterError:
            return self._convert_field_by_field()

    def _convert_field_by_field(self) -> BaseModel:
        field_values = {}

        for field_name, field_info in self.model.model_fields.items():
            field_description = field_info.description
            field_type = field_info.annotation

            response = self._ask_llm_for_field(field_name, field_description)
            value = self._process_field_value(response, field_type)
            field_values[field_name] = value

        try:
            return self.model(**field_values)
        except ValidationError as e:
            raise ConverterError(f"Failed to create model from individually collected fields: {e}")

    def _ask_llm_for_field(self, field_name: str, field_description: str) -> str:
        prompt = f"""
Based on the following information:
{self.text}

Please provide ONLY the {field_name} field value as described:
"{field_description}"

Respond with ONLY the requested information, nothing else.
"""
        return self.llm.call([
            {"role": "system", "content": f"Extract the {field_name} from the previous information."},
            {"role": "user", "content": prompt}
        ])

    def _process_field_value(self, response: str, field_type: Any) -> Any:
        response = response.strip()
        origin = get_origin(field_type)

        if origin is list:
            return self._parse_list(response)

        if field_type is float:
            return self._parse_float(response)

        if field_type is str:
            return response

        return response

    def _parse_list(self, response: str) -> list:
        try:
            if response.startswith('['):
                return json.loads(response)

            items = [item.strip() for item in response.split('\n') if item.strip()]
            return [self._strip_bullet(item) for item in items]

        except json.JSONDecodeError:
            return [response]

    def _parse_float(self, response: str) -> float:
        try:
            match = re.search(r'(\d+(\.\d+)?)', response)
            return float(match.group(1)) if match else 0.0
        except Exception:
            return 0.0

    def _strip_bullet(self, item: str) -> str:
        if item.startswith(('- ', '* ')):
            return item[2:].strip()
        return item.strip()
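As a quick orientation (not part of the diff), a minimal usage sketch of `TrainingConverter` follows, mirroring the constructor arguments exercised by the tests further down; the example model, its fields, and the chosen LLM are assumptions for illustration.

```python
from typing import List

from pydantic import BaseModel, Field

from crewai import LLM
from crewai.utilities.training_converter import TrainingConverter


# Hypothetical target model for the sketch; any Pydantic model with described fields works.
class EvaluationSketch(BaseModel):
    suggestions: List[str] = Field(description="Suggestions for improvement")
    quality: float = Field(description="Quality score between 0 and 10")


llm = LLM(model="mistral/open-mistral-7b")  # small model; the fallback path is aimed at these

converter = TrainingConverter(
    llm=llm,
    text="Initial output, human feedback and improved output go here...",
    model=EvaluationSketch,
    instructions="Convert the evaluation into JSON matching the schema.",
)

# Tries the normal structured conversion first; on ConverterError it falls back
# to asking the LLM for each field individually (_convert_field_by_field).
result = converter.to_pydantic()
print(result.quality, result.suggestions)
```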
@@ -1,13 +1,15 @@
 from unittest import mock
 from unittest.mock import MagicMock, patch

 from crewai.utilities.evaluators.task_evaluator import (
     TaskEvaluator,
     TrainingTaskEvaluation,
 )
+from crewai.utilities.converter import ConverterError


-@patch("crewai.utilities.evaluators.task_evaluator.Converter")
+@patch("crewai.utilities.evaluators.task_evaluator.TrainingConverter")
 def test_evaluate_training_data(converter_mock):
     training_data = {
         "agent_id": {
@@ -63,3 +65,39 @@ def test_evaluate_training_data(converter_mock):
             mock.call().to_pydantic(),
         ]
     )


+@patch("crewai.utilities.converter.Converter.to_pydantic")
+@patch("crewai.utilities.training_converter.TrainingConverter._convert_field_by_field")
+def test_training_converter_fallback_mechanism(convert_field_by_field_mock, to_pydantic_mock):
+    training_data = {
+        "agent_id": {
+            "data1": {
+                "initial_output": "Initial output 1",
+                "human_feedback": "Human feedback 1",
+                "improved_output": "Improved output 1",
+            },
+            "data2": {
+                "initial_output": "Initial output 2",
+                "human_feedback": "Human feedback 2",
+                "improved_output": "Improved output 2",
+            },
+        }
+    }
+    agent_id = "agent_id"
+    to_pydantic_mock.side_effect = ConverterError("Failed to convert directly")
+
+    expected_result = TrainingTaskEvaluation(
+        suggestions=["Fallback suggestion"],
+        quality=6.5,
+        final_summary="Fallback summary"
+    )
+    convert_field_by_field_mock.return_value = expected_result
+
+    original_agent = MagicMock()
+    result = TaskEvaluator(original_agent=original_agent).evaluate_training_data(
+        training_data, agent_id
+    )
+
+    assert result == expected_result
+    to_pydantic_mock.assert_called_once()
+    convert_field_by_field_mock.assert_called_once()
tests/utilities/test_training_converter.py (new file, 97 lines)
@@ -0,0 +1,97 @@
from unittest.mock import MagicMock, patch

from pydantic import BaseModel, Field
from typing import List

from crewai.utilities.converter import ConverterError
from crewai.utilities.training_converter import TrainingConverter


class TestModel(BaseModel):
    string_field: str = Field(description="A simple string field")
    list_field: List[str] = Field(description="A list of strings")
    number_field: float = Field(description="A number field")


class TestTrainingConverter:

    def setup_method(self):
        self.llm_mock = MagicMock()
        self.test_text = "Sample text for evaluation"
        self.test_instructions = "Convert to JSON format"
        self.converter = TrainingConverter(
            llm=self.llm_mock,
            text=self.test_text,
            model=TestModel,
            instructions=self.test_instructions
        )

    @patch("crewai.utilities.converter.Converter.to_pydantic")
    def test_fallback_to_field_by_field(self, parent_to_pydantic_mock):
        parent_to_pydantic_mock.side_effect = ConverterError("Failed to convert directly")

        llm_responses = {
            "string_field": "test string value",
            "list_field": "- item1\n- item2\n- item3",
            "number_field": "8.5"
        }

        def llm_side_effect(messages):
            prompt = messages[1]["content"]
            if "string_field" in prompt:
                return llm_responses["string_field"]
            elif "list_field" in prompt:
                return llm_responses["list_field"]
            elif "number_field" in prompt:
                return llm_responses["number_field"]
            return "unknown field"

        self.llm_mock.call.side_effect = llm_side_effect

        result = self.converter.to_pydantic()

        assert result.string_field == "test string value"
        assert result.list_field == ["item1", "item2", "item3"]
        assert result.number_field == 8.5

        parent_to_pydantic_mock.assert_called_once()
        assert self.llm_mock.call.call_count == 3

    def test_ask_llm_for_field(self):
        field_name = "test_field"
        field_description = "This is a test field description"
        expected_response = "Test response"
        self.llm_mock.call.return_value = expected_response
        response = self.converter._ask_llm_for_field(field_name, field_description)

        assert response == expected_response
        self.llm_mock.call.assert_called_once()

        call_args = self.llm_mock.call.call_args[0][0]
        assert call_args[0]["role"] == "system"
        assert f"Extract the {field_name}" in call_args[0]["content"]
        assert call_args[1]["role"] == "user"
        assert field_name in call_args[1]["content"]
        assert field_description in call_args[1]["content"]

    def test_process_field_value_string(self):
        response = "  This is a string with extra whitespace  "
        result = self.converter._process_field_value(response, str)
        assert result == "This is a string with extra whitespace"

    def test_process_field_value_list_with_bullet_points(self):
        response = "- Item 1\n- Item 2\n- Item 3"
        result = self.converter._process_field_value(response, List[str])
        assert result == ["Item 1", "Item 2", "Item 3"]

    def test_process_field_value_list_with_json(self):
        response = '["Item 1", "Item 2", "Item 3"]'
        with patch("crewai.utilities.training_converter.json.loads") as json_mock:
            json_mock.return_value = ["Item 1", "Item 2", "Item 3"]
            result = self.converter._process_field_value(response, List[str])
            assert result == ["Item 1", "Item 2", "Item 3"]

    def test_process_field_value_float(self):
        response = "The quality score is 8.5 out of 10"
        result = self.converter._process_field_value(response, float)
        assert result == 8.5