mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
Fix CI: Make pgvector an optional dependency, fix SQL injection and type errors
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -67,6 +67,11 @@ docling = [
|
|||||||
aisuite = [
|
aisuite = [
|
||||||
"aisuite>=0.1.10",
|
"aisuite>=0.1.10",
|
||||||
]
|
]
|
||||||
|
pgvector = [
|
||||||
|
"pgvector>=0.2.0",
|
||||||
|
"sqlalchemy>=2.0.0",
|
||||||
|
"psycopg2-binary>=2.9.0",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.uv]
|
[tool.uv]
|
||||||
dev-dependencies = [
|
dev-dependencies = [
|
||||||
|
|||||||
@@ -1 +1,5 @@
|
|||||||
from crewai.knowledge.storage.pgvector_knowledge_storage import PGVectorKnowledgeStorage
|
try:
|
||||||
|
from crewai.knowledge.storage.pgvector_knowledge_storage import PGVectorKnowledgeStorage
|
||||||
|
__all__ = ["PGVectorKnowledgeStorage"]
|
||||||
|
except ImportError:
|
||||||
|
__all__ = []
|
||||||
|
|||||||
@@ -2,24 +2,34 @@ from typing import Any, Dict, List, Optional
|
|||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from sqlalchemy import create_engine, Column, String, Text, Float
|
from sqlalchemy import create_engine, Column, String, Text
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from pgvector.sqlalchemy import Vector
|
from sqlalchemy.sql import text
|
||||||
|
|
||||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||||
from crewai.utilities import EmbeddingConfigurator
|
from crewai.utilities import EmbeddingConfigurator
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pgvector.sqlalchemy import Vector
|
||||||
|
HAS_PGVECTOR = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_PGVECTOR = False
|
||||||
|
class VectorType:
|
||||||
|
def __init__(self, dimensions: int):
|
||||||
|
self.dimensions = dimensions
|
||||||
|
Vector = VectorType # type: ignore
|
||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
class Document(Base):
|
class Document(Base): # type: ignore
|
||||||
"""SQLAlchemy model for document storage with pgvector."""
|
"""SQLAlchemy model for document storage with pgvector."""
|
||||||
__tablename__ = "documents"
|
__tablename__ = "documents"
|
||||||
|
|
||||||
id = Column(String, primary_key=True)
|
id = Column(String, primary_key=True)
|
||||||
content = Column(Text)
|
content = Column(Text)
|
||||||
metadata = Column(Text) # JSON serialized metadata
|
metadata = Column(Text) # JSON serialized metadata
|
||||||
embedding = Column(Vector(1536)) # Adjust dimension based on embedding model
|
embedding: Column = Column(Vector(1536)) # Adjust dimension based on embedding model
|
||||||
|
|
||||||
class PGVectorKnowledgeStorage(BaseKnowledgeStorage):
|
class PGVectorKnowledgeStorage(BaseKnowledgeStorage):
|
||||||
"""
|
"""
|
||||||
@@ -45,6 +55,11 @@ class PGVectorKnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
table_name: Name of the table to store documents
|
table_name: Name of the table to store documents
|
||||||
embedding_dimension: Dimension of the embedding vectors
|
embedding_dimension: Dimension of the embedding vectors
|
||||||
"""
|
"""
|
||||||
|
if not HAS_PGVECTOR:
|
||||||
|
raise ImportError(
|
||||||
|
"pgvector is not installed. Please install it with: pip install pgvector"
|
||||||
|
)
|
||||||
|
|
||||||
self.connection_string = connection_string
|
self.connection_string = connection_string
|
||||||
self.table_name = table_name
|
self.table_name = table_name
|
||||||
self.embedding_dimension = embedding_dimension
|
self.embedding_dimension = embedding_dimension
|
||||||
@@ -94,14 +109,17 @@ class PGVectorKnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
try:
|
try:
|
||||||
query_embedding = self.embedder([query[0]])[0]
|
query_embedding = self.embedder([query[0]])[0]
|
||||||
|
|
||||||
sql_query = f"""
|
sql_query = text(f"""
|
||||||
SELECT id, content, metadata, 1 - (embedding <=> '{query_embedding}') as similarity
|
SELECT id, content, metadata, 1 - (embedding <=> :query_embedding) as similarity
|
||||||
FROM {self.table_name}
|
FROM {self.table_name}
|
||||||
ORDER BY embedding <=> '{query_embedding}'
|
ORDER BY embedding <=> :query_embedding
|
||||||
LIMIT {limit}
|
LIMIT :limit
|
||||||
"""
|
""")
|
||||||
|
|
||||||
results = session.execute(sql_query).fetchall()
|
results = session.execute(
|
||||||
|
sql_query,
|
||||||
|
{"query_embedding": query_embedding, "limit": limit}
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
formatted_results = []
|
formatted_results = []
|
||||||
for row in results:
|
for row in results:
|
||||||
@@ -154,9 +172,9 @@ class PGVectorKnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
existing = session.query(Document).filter(Document.id == doc_id).first()
|
existing = session.query(Document).filter(Document.id == doc_id).first()
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
existing.content = doc
|
setattr(existing, "content", doc)
|
||||||
existing.metadata = str(meta) if meta else None
|
setattr(existing, "metadata", str(meta) if meta else None)
|
||||||
existing.embedding = embedding
|
setattr(existing, "embedding", embedding)
|
||||||
else:
|
else:
|
||||||
new_doc = Document(
|
new_doc = Document(
|
||||||
id=doc_id,
|
id=doc_id,
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
import os
|
|
||||||
import pytest
|
import pytest
|
||||||
from typing import Dict, Any, List
|
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
from crewai.knowledge.storage.pgvector_knowledge_storage import PGVectorKnowledgeStorage
|
from crewai.knowledge.storage.pgvector_knowledge_storage import PGVectorKnowledgeStorage
|
||||||
|
|||||||
Reference in New Issue
Block a user