feat: improve data training for models up to 7B parameters.

This commit is contained in:
Lucas Gomide
2025-06-30 09:50:06 -03:00
parent 576b8ff836
commit 994f0e1403
4 changed files with 227 additions and 2 deletions

View File

@@ -5,6 +5,7 @@ from pydantic import BaseModel, Field
from crewai.utilities import Converter
from crewai.utilities.events import TaskEvaluationEvent, crewai_event_bus
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
from crewai.utilities.training_converter import TrainingConverter
class Entity(BaseModel):
@@ -133,7 +134,7 @@ class TaskEvaluator:
).get_schema()
instructions = f"{instructions}\n\nThe json should have the following structure, with the following keys:\n{model_schema}"
converter = Converter(
converter = TrainingConverter(
llm=self.llm,
text=evaluation_query,
model=TrainingTaskEvaluation,

View File

@@ -0,0 +1,89 @@
import json
import re
from typing import Any, get_origin
from pydantic import BaseModel, ValidationError
from crewai.utilities.converter import Converter, ConverterError
class TrainingConverter(Converter):
    """Converter with a per-field fallback for smaller LLMs (up to 7B parameters).

    The regular ``Converter.to_pydantic`` path is tried first. When it raises
    ``ConverterError``, the target model is broken down into its individual
    fields and the LLM is queried for each field separately.
    """

    def to_pydantic(self, current_attempt=1) -> BaseModel:
        """Convert ``self.text`` into an instance of ``self.model``.

        Args:
            current_attempt: Attempt counter forwarded unchanged to the
                parent implementation.

        Returns:
            The populated model instance.

        Raises:
            ConverterError: If the field-by-field fallback also fails to
                produce a valid model.
        """
        try:
            return super().to_pydantic(current_attempt)
        except ConverterError:
            return self._convert_field_by_field()

    def _convert_field_by_field(self) -> BaseModel:
        """Build the model by asking the LLM for one field at a time.

        Raises:
            ConverterError: If the collected values do not validate against
                ``self.model``; the original ``ValidationError`` is chained
                as the cause.
        """
        field_values = {}
        for field_name, field_info in self.model.model_fields.items():
            # A field declared without a description would otherwise be
            # rendered as the literal text "None" in the prompt.
            field_description = field_info.description or field_name
            response = self._ask_llm_for_field(field_name, field_description)
            field_values[field_name] = self._process_field_value(
                response, field_info.annotation
            )

        try:
            return self.model(**field_values)
        except ValidationError as e:
            raise ConverterError(
                f"Failed to create model from individually collected fields: {e}"
            ) from e

    def _ask_llm_for_field(self, field_name: str, field_description: str) -> str:
        """Ask the LLM for the value of a single field.

        Args:
            field_name: Name of the model field being requested.
            field_description: Human-readable description of the field.

        Returns:
            Whatever ``self.llm.call`` returns for the two-message prompt.
        """
        prompt = f"""
Based on the following information:
{self.text}
Please provide ONLY the {field_name} field value as described:
"{field_description}"
Respond with ONLY the requested information, nothing else.
"""
        return self.llm.call([
            {"role": "system", "content": f"Extract the {field_name} from the previous information."},
            {"role": "user", "content": prompt},
        ])

    def _process_field_value(self, response: str, field_type: Any) -> Any:
        """Coerce a raw LLM answer according to the field's annotation.

        ``list[...]`` annotations are parsed into a list, ``float`` into a
        float; every other annotation receives the stripped text unchanged.
        """
        response = response.strip()
        if get_origin(field_type) is list:
            return self._parse_list(response)
        if field_type is float:
            return self._parse_float(response)
        return response

    def _parse_list(self, response: str) -> list:
        """Parse a JSON array or a newline-separated (optionally bulleted) list.

        A response starting with ``[`` is treated as JSON; if it does not
        decode, the whole response is returned as a single-item list.
        """
        if response.startswith('['):
            # Only the decode itself can raise, so keep the try minimal.
            try:
                return json.loads(response)
            except json.JSONDecodeError:
                return [response]

        items = [item.strip() for item in response.split('\n') if item.strip()]
        return [self._strip_bullet(item) for item in items]

    def _parse_float(self, response: str) -> float:
        """Return the first number found in ``response``, or ``0.0`` if none.

        Accepts an optional sign and a leading-dot fraction, so ``"-3.5"``
        yields ``-3.5`` and ``".5"`` yields ``0.5``.
        """
        match = re.search(r'[-+]?(?:\d+(?:\.\d+)?|\.\d+)', response)
        return float(match.group(0)) if match else 0.0

    def _strip_bullet(self, item: str) -> str:
        """Remove a leading ``- `` or ``* `` bullet marker and surrounding whitespace."""
        if item.startswith(('- ', '* ')):
            return item[2:].strip()
        return item.strip()

View File

@@ -1,13 +1,15 @@
from unittest import mock
from unittest.mock import MagicMock, patch
from crewai.utilities.evaluators.task_evaluator import (
TaskEvaluator,
TrainingTaskEvaluation,
)
from crewai.utilities.converter import ConverterError
@patch("crewai.utilities.evaluators.task_evaluator.Converter")
@patch("crewai.utilities.evaluators.task_evaluator.TrainingConverter")
def test_evaluate_training_data(converter_mock):
training_data = {
"agent_id": {
@@ -63,3 +65,39 @@ def test_evaluate_training_data(converter_mock):
mock.call().to_pydantic(),
]
)
@patch("crewai.utilities.converter.Converter.to_pydantic")
@patch("crewai.utilities.training_converter.TrainingConverter._convert_field_by_field")
def test_training_converter_fallback_mechanism(convert_field_by_field_mock, to_pydantic_mock):
    """A failed direct conversion must be answered by the field-by-field fallback."""
    # The direct conversion always fails; the fallback returns a canned evaluation.
    fallback_evaluation = TrainingTaskEvaluation(
        suggestions=["Fallback suggestion"],
        quality=6.5,
        final_summary="Fallback summary",
    )
    to_pydantic_mock.side_effect = ConverterError("Failed to convert directly")
    convert_field_by_field_mock.return_value = fallback_evaluation

    first_iteration = {
        "initial_output": "Initial output 1",
        "human_feedback": "Human feedback 1",
        "improved_output": "Improved output 1",
    }
    second_iteration = {
        "initial_output": "Initial output 2",
        "human_feedback": "Human feedback 2",
        "improved_output": "Improved output 2",
    }
    training_data = {"agent_id": {"data1": first_iteration, "data2": second_iteration}}

    evaluator = TaskEvaluator(original_agent=MagicMock())
    outcome = evaluator.evaluate_training_data(training_data, "agent_id")

    # The evaluator must surface the fallback result, having tried each path once.
    assert outcome == fallback_evaluation
    to_pydantic_mock.assert_called_once()
    convert_field_by_field_mock.assert_called_once()

View File

@@ -0,0 +1,97 @@
from unittest.mock import MagicMock, patch
from pydantic import BaseModel, Field
from typing import List
from crewai.utilities.converter import ConverterError
from crewai.utilities.training_converter import TrainingConverter
# Minimal conversion target for the TrainingConverter tests: one plain string,
# one list of strings and one float, each carrying a description.
# NOTE(review): the "Test" name prefix may make pytest attempt to collect this
# model as a test class — confirm and rename or mark it if it emits a warning.
class TestModel(BaseModel):
    string_field: str = Field(
        description="A simple string field",
    )
    list_field: List[str] = Field(
        description="A list of strings",
    )
    number_field: float = Field(
        description="A number field",
    )
class TestTrainingConverter:
    """Unit tests for ``TrainingConverter`` driven by a mocked LLM."""

    def setup_method(self):
        """Create a fresh converter around a ``MagicMock`` LLM for every test."""
        self.llm_mock = MagicMock()
        self.test_text = "Sample text for evaluation"
        self.test_instructions = "Convert to JSON format"
        self.converter = TrainingConverter(
            llm=self.llm_mock,
            text=self.test_text,
            model=TestModel,
            instructions=self.test_instructions,
        )

    @patch("crewai.utilities.converter.Converter.to_pydantic")
    def test_fallback_to_field_by_field(self, parent_to_pydantic_mock):
        """When the parent conversion fails, each field is requested separately."""
        parent_to_pydantic_mock.side_effect = ConverterError("Failed to convert directly")

        llm_responses = {
            "string_field": "test string value",
            "list_field": "- item1\n- item2\n- item3",
            "number_field": "8.5",
        }

        def llm_side_effect(messages):
            # The user message names the requested field; answer with the
            # canned response of the first field name found in it.
            prompt = messages[1]["content"]
            for field_name, canned_response in llm_responses.items():
                if field_name in prompt:
                    return canned_response
            return "unknown field"

        self.llm_mock.call.side_effect = llm_side_effect

        result = self.converter.to_pydantic()

        assert result.string_field == "test string value"
        assert result.list_field == ["item1", "item2", "item3"]
        assert result.number_field == 8.5
        parent_to_pydantic_mock.assert_called_once()
        assert self.llm_mock.call.call_count == 3

    def test_ask_llm_for_field(self):
        """The LLM receives a system and a user message mentioning the field."""
        field_name = "test_field"
        field_description = "This is a test field description"
        expected_response = "Test response"
        self.llm_mock.call.return_value = expected_response

        response = self.converter._ask_llm_for_field(field_name, field_description)

        assert response == expected_response
        self.llm_mock.call.assert_called_once()
        call_args = self.llm_mock.call.call_args[0][0]
        assert call_args[0]["role"] == "system"
        assert f"Extract the {field_name}" in call_args[0]["content"]
        assert call_args[1]["role"] == "user"
        assert field_name in call_args[1]["content"]
        assert field_description in call_args[1]["content"]

    def test_process_field_value_string(self):
        """String fields are returned with surrounding whitespace removed."""
        response = " This is a string with extra whitespace "
        result = self.converter._process_field_value(response, str)
        assert result == "This is a string with extra whitespace"

    def test_process_field_value_list_with_bullet_points(self):
        """Bulleted lines become individual list items without the bullet."""
        response = "- Item 1\n- Item 2\n- Item 3"
        result = self.converter._process_field_value(response, List[str])
        assert result == ["Item 1", "Item 2", "Item 3"]

    def test_process_field_value_list_with_json(self):
        """A JSON array response is decoded by the real JSON parser."""
        # json.loads is deliberately not mocked: mocking it to return the
        # expected list would make this assertion pass regardless of parsing.
        response = '["Item 1", "Item 2", "Item 3"]'
        result = self.converter._process_field_value(response, List[str])
        assert result == ["Item 1", "Item 2", "Item 3"]

    def test_process_field_value_float(self):
        """The first number embedded in free text is extracted as a float."""
        response = "The quality score is 8.5 out of 10"
        result = self.converter._process_field_value(response, float)
        assert result == 8.5