mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
adding in lorenze feedback
This commit is contained in:
@@ -126,7 +126,7 @@ class Agent(BaseAgent):
|
|||||||
default="safe",
|
default="safe",
|
||||||
description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
|
description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
|
||||||
)
|
)
|
||||||
# TODO: We need to add in knowledge config (score, top_k, etc)
|
# TODO: Lorenze add knowledge_embedder. Support direct class or config dict.
|
||||||
_knowledge: Optional[Knowledge] = PrivateAttr(default=None)
|
_knowledge: Optional[Knowledge] = PrivateAttr(default=None)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
@@ -279,10 +279,8 @@ class Agent(BaseAgent):
|
|||||||
if self._knowledge:
|
if self._knowledge:
|
||||||
# Query the knowledge base for relevant information
|
# Query the knowledge base for relevant information
|
||||||
knowledge_snippets = self._knowledge.query(query=task.prompt())
|
knowledge_snippets = self._knowledge.query(query=task.prompt())
|
||||||
print("knowledge_snippets", knowledge_snippets)
|
|
||||||
if knowledge_snippets:
|
if knowledge_snippets:
|
||||||
formatted_knowledge = "\n".join(knowledge_snippets)
|
formatted_knowledge = "\n".join(knowledge_snippets)
|
||||||
print("formatted_knowledge", formatted_knowledge)
|
|
||||||
task_prompt += f"\n\nAdditional Information:\n{formatted_knowledge}"
|
task_prompt += f"\n\nAdditional Information:\n{formatted_knowledge}"
|
||||||
|
|
||||||
tools = tools or self.tools or []
|
tools = tools or self.tools or []
|
||||||
|
|||||||
@@ -0,0 +1,82 @@
|
|||||||
|
import os
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from .base_embedder import BaseEmbedder
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaEmbedder(BaseEmbedder):
|
||||||
|
"""
|
||||||
|
A wrapper class for text embedding models using Ollama's API
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
base_url: str = "http://localhost:11434/v1",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the embedding model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the model to use
|
||||||
|
api_key: API key (defaults to 'ollama' or environment variable 'OLLAMA_API_KEY')
|
||||||
|
base_url: Base URL for the Ollama API (default is 'http://localhost:11434/v1')
|
||||||
|
"""
|
||||||
|
self.model_name = model_name
|
||||||
|
self.api_key = api_key or os.getenv("OLLAMA_API_KEY") or "ollama"
|
||||||
|
self.base_url = base_url
|
||||||
|
self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)
|
||||||
|
|
||||||
|
def embed_chunks(self, chunks: List[str]) -> List[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Generate embeddings for a list of text chunks
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunks: List of text chunks to embed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of embeddings
|
||||||
|
"""
|
||||||
|
return self.embed_texts(chunks)
|
||||||
|
|
||||||
|
def embed_texts(self, texts: List[str]) -> List[np.ndarray]:
|
||||||
|
"""
|
||||||
|
Generate embeddings for a list of texts
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of texts to embed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of embeddings
|
||||||
|
"""
|
||||||
|
embeddings = []
|
||||||
|
max_batch_size = 2048 # Adjust batch size if necessary
|
||||||
|
for i in range(0, len(texts), max_batch_size):
|
||||||
|
batch = texts[i : i + max_batch_size]
|
||||||
|
response = self.client.embeddings.create(input=batch, model=self.model_name)
|
||||||
|
batch_embeddings = [np.array(item.embedding) for item in response.data]
|
||||||
|
embeddings.extend(batch_embeddings)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def embed_text(self, text: str) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Generate embedding for a single text
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to embed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Embedding array
|
||||||
|
"""
|
||||||
|
return self.embed_texts([text])[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dimension(self) -> int:
|
||||||
|
"""Get the dimension of the embeddings"""
|
||||||
|
# Embedding dimensions may vary; we'll determine it dynamically
|
||||||
|
test_embed = self.embed_text("test")
|
||||||
|
return len(test_embed)
|
||||||
|
|||||||
@@ -1,82 +0,0 @@
|
|||||||
import os
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from openai import OpenAI
|
|
||||||
|
|
||||||
from .base_embedder import BaseEmbedder
|
|
||||||
|
|
||||||
|
|
||||||
class OllamaEmbedder(BaseEmbedder):
|
|
||||||
"""
|
|
||||||
A wrapper class for text embedding models using Ollama's API
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
api_key: Optional[str] = None,
|
|
||||||
base_url: str = "http://localhost:11434/v1",
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize the embedding model
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: Name of the model to use
|
|
||||||
api_key: API key (defaults to 'ollama' or environment variable 'OLLAMA_API_KEY')
|
|
||||||
base_url: Base URL for the Ollama API (default is 'http://localhost:11434/v1')
|
|
||||||
"""
|
|
||||||
self.model_name = model_name
|
|
||||||
self.api_key = api_key or os.getenv("OLLAMA_API_KEY") or "ollama"
|
|
||||||
self.base_url = base_url
|
|
||||||
self.client = OpenAI(base_url=self.base_url, api_key=self.api_key)
|
|
||||||
|
|
||||||
def embed_chunks(self, chunks: List[str]) -> List[np.ndarray]:
|
|
||||||
"""
|
|
||||||
Generate embeddings for a list of text chunks
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chunks: List of text chunks to embed
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of embeddings
|
|
||||||
"""
|
|
||||||
return self.embed_texts(chunks)
|
|
||||||
|
|
||||||
def embed_texts(self, texts: List[str]) -> List[np.ndarray]:
|
|
||||||
"""
|
|
||||||
Generate embeddings for a list of texts
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: List of texts to embed
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of embeddings
|
|
||||||
"""
|
|
||||||
embeddings = []
|
|
||||||
max_batch_size = 2048 # Adjust batch size if necessary
|
|
||||||
for i in range(0, len(texts), max_batch_size):
|
|
||||||
batch = texts[i : i + max_batch_size]
|
|
||||||
response = self.client.embeddings.create(input=batch, model=self.model_name)
|
|
||||||
batch_embeddings = [np.array(item.embedding) for item in response.data]
|
|
||||||
embeddings.extend(batch_embeddings)
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
def embed_text(self, text: str) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Generate embedding for a single text
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Text to embed
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Embedding array
|
|
||||||
"""
|
|
||||||
return self.embed_texts([text])[0]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dimension(self) -> int:
|
|
||||||
"""Get the dimension of the embeddings"""
|
|
||||||
# Embedding dimensions may vary; we'll determine it dynamically
|
|
||||||
test_embed = self.embed_text("test")
|
|
||||||
return len(test_embed)
|
|
||||||
Reference in New Issue
Block a user