implementing initial LLM class
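The tests added below pin down the public surface of the new LLM class: a constructor that keeps OpenAI-style parameters (model, temperature, stop, callbacks, and so on) as plain attributes, plus a call(messages) method that returns the completion text. As a hedged reference only, here is a minimal sketch of that surface inferred from the assertions; the attribute names come from the tests, while the **params storage and the litellm delegation are illustrative assumptions, not necessarily this commit's implementation.

# Sketch only: attribute names are taken from the test assertions below;
# the litellm delegation is an assumption made for illustration.
from typing import Any, Callable, Dict, List, Optional

import litellm


class LLM:
    def __init__(
        self,
        model: str,
        callbacks: Optional[List[Callable[..., Any]]] = None,
        **params: Any,
    ):
        self.model = model
        self.callbacks = callbacks or []
        # Keep every keyword as a plain attribute so assertions such as
        # `agent.llm.temperature == 0.7` and `agent.llm.stop == ["STOP", "END"]`
        # hold, and forward the same keywords on each completion call.
        for name, value in params.items():
            setattr(self, name, value)
        self._params = params

    def call(self, messages: List[Dict[str, str]]) -> str:
        """Send the chat messages to the configured model and return the text."""
        response = litellm.completion(
            model=self.model, messages=messages, **self._params
        )
        return response.choices[0].message.content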
@@ -51,6 +51,50 @@ def test_custom_llm_with_langchain():
     assert agent.llm == "gpt-4"


+def test_custom_llm_temperature_preservation():
+    from langchain_openai import ChatOpenAI
+
+    langchain_llm = ChatOpenAI(temperature=0.7, model="gpt-4")
+    agent = Agent(
+        role="temperature test role",
+        goal="temperature test goal",
+        backstory="temperature test backstory",
+        llm=langchain_llm,
+    )
+
+    assert isinstance(agent.llm, LLM)
+    assert agent.llm.model == "gpt-4"
+    assert agent.llm.temperature == 0.7
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_agent_execute_task():
+    from langchain_openai import ChatOpenAI
+    from crewai import Task
+
+    agent = Agent(
+        role="Math Tutor",
+        goal="Solve math problems accurately",
+        backstory="You are an experienced math tutor with a knack for explaining complex concepts simply.",
+        llm=ChatOpenAI(temperature=0.7, model="gpt-4o-mini"),
+    )
+
+    task = Task(
+        description="Calculate the area of a circle with radius 5 cm.",
+        expected_output="The calculated area of the circle in square centimeters.",
+        agent=agent,
+    )
+
+    result = agent.execute_task(task)
+
+    assert result is not None
+    assert (
+        "The area of a circle with a radius of 5 cm is calculated using the formula A = πr^2, where A is the area and r is the...in the values, we get A = π*5^2 = 25π square centimeters. Therefore, the area of the circle is 25π square centimeters."
+        in result
+    )
+    assert "square centimeters" in result.lower()
+
+
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_agent_execution():
     agent = Agent(
@@ -67,7 +111,7 @@ def test_agent_execution():
     )

     output = agent.execute_task(task)
-    assert output == "1 + 1 = 2"
+    assert output == "The result of the math operation 1 + 1 is 2."


 @pytest.mark.vcr(filter_headers=["authorization"])
@@ -182,7 +226,7 @@ def test_cache_hitting():
     task = Task(
         description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.",
         agent=agent,
-        expected_output="The number that is the result of the multiplication.",
+        expected_output="The number that is the result of the multiplication tool.",
     )
     output = agent.execute_task(task)
     assert output == "0"
@@ -1102,6 +1146,88 @@ def test_agent_max_retry_limit():
     )


+def test_agent_with_llm():
+    agent = Agent(
+        role="test role",
+        goal="test goal",
+        backstory="test backstory",
+        llm=LLM(model="gpt-3.5-turbo", temperature=0.7),
+    )
+
+    assert isinstance(agent.llm, LLM)
+    assert agent.llm.model == "gpt-3.5-turbo"
+    assert agent.llm.temperature == 0.7
+
+
+def test_agent_with_custom_stop_words():
+    stop_words = ["STOP", "END"]
+    agent = Agent(
+        role="test role",
+        goal="test goal",
+        backstory="test backstory",
+        llm=LLM(model="gpt-3.5-turbo", stop=stop_words),
+    )
+
+    assert isinstance(agent.llm, LLM)
+    assert agent.llm.stop == stop_words
+
+
+def test_agent_with_callbacks():
+    def dummy_callback(response):
+        pass
+
+    agent = Agent(
+        role="test role",
+        goal="test goal",
+        backstory="test backstory",
+        llm=LLM(model="gpt-3.5-turbo", callbacks=[dummy_callback]),
+    )
+
+    assert isinstance(agent.llm, LLM)
+    assert len(agent.llm.callbacks) == 1
+    assert agent.llm.callbacks[0] == dummy_callback
+
+
+def test_agent_with_additional_kwargs():
+    agent = Agent(
+        role="test role",
+        goal="test goal",
+        backstory="test backstory",
+        llm=LLM(
+            model="gpt-3.5-turbo",
+            temperature=0.8,
+            top_p=0.9,
+            presence_penalty=0.1,
+            frequency_penalty=0.1,
+        ),
+    )
+
+    assert isinstance(agent.llm, LLM)
+    assert agent.llm.model == "gpt-3.5-turbo"
+    assert agent.llm.temperature == 0.8
+    assert agent.llm.top_p == 0.9
+    assert agent.llm.presence_penalty == 0.1
+    assert agent.llm.frequency_penalty == 0.1
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_llm_call():
+    llm = LLM(model="gpt-3.5-turbo")
+    messages = [{"role": "user", "content": "Say 'Hello, World!'"}]
+
+    response = llm.call(messages)
+    assert "Hello, World!" in response
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_llm_call_with_error():
+    llm = LLM(model="non-existent-model")
+    messages = [{"role": "user", "content": "This should fail"}]
+
+    with pytest.raises(Exception):
+        llm.call(messages)
+
+
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_handle_context_length_exceeds_limit():
     agent = Agent(
@@ -1172,3 +1298,213 @@ def test_handle_context_length_exceeds_limit_cli_no():
         CrewAgentExecutor, "_handle_context_length"
     ) as mock_handle_context:
         mock_handle_context.assert_not_called()
+
+
+def test_agent_with_all_llm_attributes():
+    agent = Agent(
+        role="test role",
+        goal="test goal",
+        backstory="test backstory",
+        llm=LLM(
+            model="gpt-3.5-turbo",
+            timeout=10,
+            temperature=0.7,
+            top_p=0.9,
+            n=1,
+            stop=["STOP", "END"],
+            max_tokens=100,
+            presence_penalty=0.1,
+            frequency_penalty=0.1,
+            logit_bias={50256: -100},  # Example: bias against the EOT token
+            response_format={"type": "json_object"},
+            seed=42,
+            logprobs=True,
+            top_logprobs=5,
+            base_url="https://api.openai.com/v1",
+            api_version="2023-05-15",
+            api_key="sk-your-api-key-here",
+        ),
+    )
+
+    assert isinstance(agent.llm, LLM)
+    assert agent.llm.model == "gpt-3.5-turbo"
+    assert agent.llm.timeout == 10
+    assert agent.llm.temperature == 0.7
+    assert agent.llm.top_p == 0.9
+    assert agent.llm.n == 1
+    assert agent.llm.stop == ["STOP", "END"]
+    assert agent.llm.max_tokens == 100
+    assert agent.llm.presence_penalty == 0.1
+    assert agent.llm.frequency_penalty == 0.1
+    assert agent.llm.logit_bias == {50256: -100}
+    assert agent.llm.response_format == {"type": "json_object"}
+    assert agent.llm.seed == 42
+    assert agent.llm.logprobs
+    assert agent.llm.top_logprobs == 5
+    assert agent.llm.base_url == "https://api.openai.com/v1"
+    assert agent.llm.api_version == "2023-05-15"
+    assert agent.llm.api_key == "sk-your-api-key-here"
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_llm_call_with_all_attributes():
+    llm = LLM(
+        model="gpt-3.5-turbo",
+        temperature=0.7,
+        max_tokens=50,
+        stop=["STOP"],
+        presence_penalty=0.1,
+        frequency_penalty=0.1,
+    )
+    messages = [{"role": "user", "content": "Say 'Hello, World!' and then say STOP"}]
+
+    response = llm.call(messages)
+    assert "Hello, World!" in response
+    assert "STOP" not in response
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_agent_with_ollama_gemma():
+    agent = Agent(
+        role="test role",
+        goal="test goal",
+        backstory="test backstory",
+        llm=LLM(
+            model="ollama/gemma2:latest",
+            base_url="http://localhost:8080",
+        ),
+    )
+
+    assert isinstance(agent.llm, LLM)
+    assert agent.llm.model == "ollama/gemma2:latest"
+    assert agent.llm.base_url == "http://localhost:8080"
+
+    task = "Respond in 20 words. Who are you?"
+    response = agent.llm.call([{"role": "user", "content": task}])
+
+    assert response
+    assert len(response.split()) <= 25  # Allow a little flexibility in word count
+    assert "Gemma" in response or "AI" in response or "language model" in response
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_llm_call_with_ollama_gemma():
+    llm = LLM(
+        model="ollama/gemma2:latest",
+        base_url="http://localhost:8080",
+        temperature=0.7,
+        max_tokens=30,
+    )
+    messages = [{"role": "user", "content": "Respond in 20 words. Who are you?"}]
+
+    response = llm.call(messages)
+
+    assert response
+    assert len(response.split()) <= 25  # Allow a little flexibility in word count
+    assert "Gemma" in response or "AI" in response or "language model" in response
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_agent_execute_task_basic():
+    agent = Agent(
+        role="test role",
+        goal="test goal",
+        backstory="test backstory",
+        llm=LLM(model="gpt-3.5-turbo"),
+    )
+
+    task = Task(
+        description="Calculate 2 + 2",
+        expected_output="The result of the calculation",
+        agent=agent,
+    )
+
+    result = agent.execute_task(task)
+    assert "4" in result
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_agent_execute_task_with_context():
+    agent = Agent(
+        role="test role",
+        goal="test goal",
+        backstory="test backstory",
+        llm=LLM(model="gpt-3.5-turbo"),
+    )
+
+    task = Task(
+        description="Summarize the given context in one sentence",
+        expected_output="A one-sentence summary",
+        agent=agent,
+    )
+
+    context = "The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet."
+
+    result = agent.execute_task(task, context=context)
+    assert len(result.split(".")) == 3
+    assert "fox" in result.lower() and "dog" in result.lower()
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_agent_execute_task_with_tool():
+    @tool
+    def dummy_tool(query: str) -> str:
+        """Useful for when you need to get a dummy result for a query."""
+        return f"Dummy result for: {query}"
+
+    agent = Agent(
+        role="test role",
+        goal="test goal",
+        backstory="test backstory",
+        llm=LLM(model="gpt-3.5-turbo"),
+        tools=[dummy_tool],
+    )
+
+    task = Task(
+        description="Use the dummy tool to get a result for 'test query'",
+        expected_output="The result from the dummy tool",
+        agent=agent,
+    )
+
+    result = agent.execute_task(task)
+    assert "Dummy result for: test query" in result
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_agent_execute_task_with_custom_llm():
+    agent = Agent(
+        role="test role",
+        goal="test goal",
+        backstory="test backstory",
+        llm=LLM(model="gpt-3.5-turbo", temperature=0.7, max_tokens=50),
+    )
+
+    task = Task(
+        description="Write a haiku about AI",
+        expected_output="A haiku (3 lines, 5-7-5 syllable pattern) about AI",
+        agent=agent,
+    )
+
+    result = agent.execute_task(task)
+    assert len(result.split("\n")) == 3
+    assert "AI" in result or "artificial intelligence" in result.lower()
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_agent_execute_task_with_ollama():
+    agent = Agent(
+        role="test role",
+        goal="test goal",
+        backstory="test backstory",
+        llm=LLM(model="ollama/gemma2:latest", base_url="http://localhost:8080"),
+    )
+
+    task = Task(
+        description="Explain what AI is in one sentence",
+        expected_output="A one-sentence explanation of AI",
+        agent=agent,
+    )
+
+    result = agent.execute_task(task)
+    assert len(result.split(".")) == 2
+    assert "AI" in result or "artificial intelligence" in result.lower()