"""
PostgreSQL Vector Store with pgvector for Multi-Document Chat with RAG.

Persistent vector storage using PostgreSQL with the pgvector extension.
"""
import hashlib
import os
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Any, Dict, List, Optional

import psycopg2
from pgvector.psycopg2 import register_vector
from psycopg2.extras import Json, RealDictCursor, execute_values
@dataclass
|
|
class DocumentChunk:
|
|
"""A chunk of a document with metadata."""
|
|
chunk_id: str
|
|
document_id: str
|
|
content: str
|
|
chunk_index: int
|
|
embedding: Optional[List[float]] = None
|
|
metadata: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
@dataclass
|
|
class Document:
|
|
"""Document metadata with tags."""
|
|
document_id: str
|
|
user_id: str
|
|
site_id: str
|
|
file_path: str
|
|
filename: str
|
|
tags: List[str]
|
|
content_hash: str
|
|
created_at: str
|
|
updated_at: str
|
|
chunk_count: int
|
|
metadata: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
class EmbeddingProvider:
|
|
"""Base class for embedding providers."""
|
|
|
|
def embed_text(self, text: str) -> List[float]:
|
|
"""Generate embedding for text."""
|
|
raise NotImplementedError
|
|
|
|
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
|
"""Generate embeddings for multiple texts."""
|
|
return [self.embed_text(text) for text in texts]
|
|
|
|
|
|
class OllamaEmbeddings(EmbeddingProvider):
|
|
"""Ollama embeddings using local models."""
|
|
|
|
def __init__(self, base_url: str = "http://localhost:11434", model: str = "nomic-embed-text"):
|
|
"""
|
|
Initialize Ollama embeddings.
|
|
|
|
Args:
|
|
base_url: Ollama server URL
|
|
model: Embedding model (e.g., nomic-embed-text, mxbai-embed-large)
|
|
"""
|
|
import requests
|
|
self.base_url = base_url.rstrip('/')
|
|
self.model = model
|
|
self.requests = requests
|
|
|
|
def embed_text(self, text: str) -> List[float]:
|
|
"""Generate embedding using Ollama."""
|
|
response = self.requests.post(
|
|
f"{self.base_url}/api/embeddings",
|
|
json={
|
|
"model": self.model,
|
|
"prompt": text
|
|
}
|
|
)
|
|
response.raise_for_status()
|
|
return response.json()["embedding"]
|
|
|
|
|
|
class OpenAIEmbeddings(EmbeddingProvider):
|
|
"""OpenAI embeddings."""
|
|
|
|
def __init__(self, api_key: str, model: str = "text-embedding-3-small"):
|
|
"""
|
|
Initialize OpenAI embeddings.
|
|
|
|
Args:
|
|
api_key: OpenAI API key
|
|
model: Embedding model (text-embedding-3-small or text-embedding-3-large)
|
|
"""
|
|
import requests
|
|
self.api_key = api_key
|
|
self.model = model
|
|
self.requests = requests
|
|
|
|
def embed_text(self, text: str) -> List[float]:
|
|
"""Generate embedding using OpenAI."""
|
|
response = self.requests.post(
|
|
"https://api.openai.com/v1/embeddings",
|
|
headers={
|
|
"Authorization": f"Bearer {self.api_key}",
|
|
"Content-Type": "application/json"
|
|
},
|
|
json={
|
|
"input": text,
|
|
"model": self.model
|
|
}
|
|
)
|
|
response.raise_for_status()
|
|
return response.json()["data"][0]["embedding"]
|
|
|
|
|
|
class PostgreSQLVectorStore:
|
|
"""
|
|
PostgreSQL vector store with pgvector extension.
|
|
|
|
Provides persistent storage for document embeddings with tag-based organization.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
embedding_provider: EmbeddingProvider,
|
|
connection_string: Optional[str] = None,
|
|
table_prefix: str = ""
|
|
):
|
|
"""
|
|
Initialize PostgreSQL vector store.
|
|
|
|
Args:
|
|
embedding_provider: Provider for generating embeddings
|
|
connection_string: PostgreSQL connection string (uses env vars if not provided)
|
|
table_prefix: Prefix for table names (e.g., "dev_" or "prod_")
|
|
"""
|
|
self.embedding_provider = embedding_provider
|
|
self.table_prefix = table_prefix
|
|
|
|
# Use connection string or build from env vars
|
|
if connection_string:
|
|
self.connection_string = connection_string
|
|
else:
|
|
self.connection_string = (
|
|
f"host={os.getenv('POSTGRES_HOST', 'localhost')} "
|
|
f"port={os.getenv('POSTGRES_PORT', '5432')} "
|
|
f"dbname={os.getenv('POSTGRES_DB', 'sharepoint_vectors')} "
|
|
f"user={os.getenv('POSTGRES_USER', 'postgres')} "
|
|
f"password={os.getenv('POSTGRES_PASSWORD', 'postgres')}"
|
|
)
|
|
|
|
# Initialize database connection
|
|
self.conn = psycopg2.connect(self.connection_string)
|
|
self.conn.autocommit = False
|
|
|
|
# Create tables and enable extension if they don't exist
|
|
self._create_tables()
|
|
|
|
# Register pgvector type (must be after extension is enabled)
|
|
register_vector(self.conn)
|
|
|
|
def _create_tables(self):
|
|
"""Create tables for documents, chunks, and tags."""
|
|
with self.conn.cursor() as cur:
|
|
# Enable pgvector extension
|
|
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
|
|
|
# Documents table
|
|
cur.execute(f"""
|
|
CREATE TABLE IF NOT EXISTS {self.table_prefix}documents (
|
|
document_id VARCHAR(64) PRIMARY KEY,
|
|
user_id VARCHAR(255) NOT NULL,
|
|
site_id VARCHAR(255) NOT NULL,
|
|
file_path TEXT NOT NULL,
|
|
filename VARCHAR(255) NOT NULL,
|
|
content_hash VARCHAR(64) NOT NULL,
|
|
created_at TIMESTAMP NOT NULL,
|
|
updated_at TIMESTAMP NOT NULL,
|
|
chunk_count INTEGER NOT NULL,
|
|
metadata JSONB,
|
|
UNIQUE(user_id, site_id, file_path)
|
|
)
|
|
""")
|
|
|
|
# Document chunks table with vector embeddings
|
|
# Assuming 768 dimensions for nomic-embed-text (adjust based on your model)
|
|
cur.execute(f"""
|
|
CREATE TABLE IF NOT EXISTS {self.table_prefix}document_chunks (
|
|
chunk_id VARCHAR(128) PRIMARY KEY,
|
|
document_id VARCHAR(64) NOT NULL REFERENCES {self.table_prefix}documents(document_id) ON DELETE CASCADE,
|
|
content TEXT NOT NULL,
|
|
chunk_index INTEGER NOT NULL,
|
|
embedding vector(768),
|
|
metadata JSONB,
|
|
UNIQUE(document_id, chunk_index)
|
|
)
|
|
""")
|
|
|
|
# Tags table (many-to-many with documents)
|
|
cur.execute(f"""
|
|
CREATE TABLE IF NOT EXISTS {self.table_prefix}document_tags (
|
|
document_id VARCHAR(64) NOT NULL REFERENCES {self.table_prefix}documents(document_id) ON DELETE CASCADE,
|
|
tag VARCHAR(100) NOT NULL,
|
|
PRIMARY KEY(document_id, tag)
|
|
)
|
|
""")
|
|
|
|
# Create indexes for performance
|
|
cur.execute(f"""
|
|
CREATE INDEX IF NOT EXISTS idx_documents_user_id
|
|
ON {self.table_prefix}documents(user_id)
|
|
""")
|
|
|
|
cur.execute(f"""
|
|
CREATE INDEX IF NOT EXISTS idx_document_tags_tag
|
|
ON {self.table_prefix}document_tags(tag)
|
|
""")
|
|
|
|
# Create IVFFlat index for vector similarity search (after some data is inserted)
|
|
# This is commented out - create it manually after inserting ~1000+ vectors:
|
|
# CREATE INDEX ON document_chunks USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);
|
|
|
|
self.conn.commit()
|
|
|
|
def add_document(
|
|
self,
|
|
user_id: str,
|
|
site_id: str,
|
|
file_path: str,
|
|
filename: str,
|
|
content: str,
|
|
tags: List[str],
|
|
chunk_size: int = 1000,
|
|
chunk_overlap: int = 200,
|
|
metadata: Optional[Dict[str, Any]] = None
|
|
) -> str:
|
|
"""
|
|
Add a document to the vector store with chunking and embeddings.
|
|
|
|
Args:
|
|
user_id: User ID
|
|
site_id: SharePoint site ID
|
|
file_path: File path
|
|
filename: Filename
|
|
content: Document content
|
|
tags: List of tags (e.g., ["HR", "SALES", "Q4-2024"])
|
|
chunk_size: Size of each chunk in characters
|
|
chunk_overlap: Overlap between chunks
|
|
metadata: Additional metadata
|
|
|
|
Returns:
|
|
document_id
|
|
"""
|
|
# Generate document ID
|
|
document_id = self._generate_document_id(user_id, site_id, file_path)
|
|
|
|
# Calculate content hash
|
|
content_hash = hashlib.sha256(content.encode()).hexdigest()
|
|
|
|
with self.conn.cursor() as cur:
|
|
# Check if document already exists
|
|
cur.execute(
|
|
f"SELECT content_hash FROM {self.table_prefix}documents WHERE document_id = %s",
|
|
(document_id,)
|
|
)
|
|
existing = cur.fetchone()
|
|
|
|
if existing and existing[0] == content_hash:
|
|
# Content hasn't changed, just update tags if different
|
|
cur.execute(
|
|
f"SELECT tag FROM {self.table_prefix}document_tags WHERE document_id = %s",
|
|
(document_id,)
|
|
)
|
|
existing_tags = [row[0] for row in cur.fetchall()]
|
|
|
|
if set(tags) != set(existing_tags):
|
|
self._update_tags(cur, document_id, tags)
|
|
cur.execute(
|
|
f"UPDATE {self.table_prefix}documents SET updated_at = %s WHERE document_id = %s",
|
|
(datetime.utcnow(), document_id)
|
|
)
|
|
self.conn.commit()
|
|
|
|
return document_id
|
|
|
|
# Delete old chunks if content changed
|
|
if existing:
|
|
cur.execute(
|
|
f"DELETE FROM {self.table_prefix}document_chunks WHERE document_id = %s",
|
|
(document_id,)
|
|
)
|
|
|
|
# Chunk the document
|
|
chunks = self._chunk_text(content, chunk_size, chunk_overlap)
|
|
|
|
# Generate embeddings for all chunks
|
|
embeddings = self.embedding_provider.embed_batch(chunks)
|
|
|
|
# Insert or update document
|
|
now = datetime.utcnow()
|
|
if existing:
|
|
cur.execute(f"""
|
|
UPDATE {self.table_prefix}documents
|
|
SET content_hash = %s, updated_at = %s, chunk_count = %s, metadata = %s
|
|
WHERE document_id = %s
|
|
""", (content_hash, now, len(chunks), psycopg2.extras.Json(metadata), document_id))
|
|
else:
|
|
cur.execute(f"""
|
|
INSERT INTO {self.table_prefix}documents
|
|
(document_id, user_id, site_id, file_path, filename, content_hash, created_at, updated_at, chunk_count, metadata)
|
|
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
|
""", (document_id, user_id, site_id, file_path, filename, content_hash, now, now, len(chunks), psycopg2.extras.Json(metadata)))
|
|
|
|
# Insert chunks with embeddings
|
|
chunk_data = [
|
|
(
|
|
f"{document_id}_chunk_{idx}",
|
|
document_id,
|
|
chunk_text,
|
|
idx,
|
|
embedding,
|
|
psycopg2.extras.Json(metadata)
|
|
)
|
|
for idx, (chunk_text, embedding) in enumerate(zip(chunks, embeddings))
|
|
]
|
|
|
|
execute_values(
|
|
cur,
|
|
f"""
|
|
INSERT INTO {self.table_prefix}document_chunks
|
|
(chunk_id, document_id, content, chunk_index, embedding, metadata)
|
|
VALUES %s
|
|
""",
|
|
chunk_data
|
|
)
|
|
|
|
# Update tags
|
|
self._update_tags(cur, document_id, tags)
|
|
|
|
self.conn.commit()
|
|
|
|
return document_id
|
|
|
|
def search(
|
|
self,
|
|
user_id: str,
|
|
query: str,
|
|
tags: Optional[List[str]] = None,
|
|
top_k: int = 5
|
|
) -> List[Dict[str, Any]]:
|
|
"""
|
|
Search for relevant document chunks using vector similarity.
|
|
|
|
Args:
|
|
user_id: User ID (for access control)
|
|
query: Search query
|
|
tags: Optional list of tags to filter by
|
|
top_k: Number of results to return
|
|
|
|
Returns:
|
|
List of relevant chunks with similarity scores
|
|
"""
|
|
# Generate query embedding
|
|
query_embedding = self.embedding_provider.embed_text(query)
|
|
|
|
with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
|
|
# Build query based on whether tags are specified
|
|
if tags:
|
|
# Filter by tags using JOIN
|
|
tag_placeholders = ','.join(['%s'] * len(tags))
|
|
cur.execute(f"""
|
|
SELECT
|
|
c.chunk_id, c.document_id, c.content, c.chunk_index, c.metadata as chunk_metadata,
|
|
d.user_id, d.site_id, d.file_path, d.filename, d.created_at, d.updated_at,
|
|
d.chunk_count, d.content_hash, d.metadata as doc_metadata,
|
|
ARRAY_AGG(dt.tag) as tags,
|
|
1 - (c.embedding <=> %s::vector) as similarity
|
|
FROM {self.table_prefix}document_chunks c
|
|
JOIN {self.table_prefix}documents d ON c.document_id = d.document_id
|
|
JOIN {self.table_prefix}document_tags dt ON d.document_id = dt.document_id
|
|
WHERE d.user_id = %s AND LOWER(dt.tag) IN ({tag_placeholders})
|
|
GROUP BY c.chunk_id, c.document_id, c.content, c.chunk_index, c.metadata,
|
|
d.user_id, d.site_id, d.file_path, d.filename, d.created_at,
|
|
d.updated_at, d.chunk_count, d.content_hash, d.metadata, c.embedding
|
|
ORDER BY similarity DESC
|
|
LIMIT %s
|
|
""", [query_embedding, user_id] + [tag.lower() for tag in tags] + [top_k])
|
|
else:
|
|
# All user's documents
|
|
cur.execute(f"""
|
|
SELECT
|
|
c.chunk_id, c.document_id, c.content, c.chunk_index, c.metadata as chunk_metadata,
|
|
d.user_id, d.site_id, d.file_path, d.filename, d.created_at, d.updated_at,
|
|
d.chunk_count, d.content_hash, d.metadata as doc_metadata,
|
|
ARRAY_AGG(dt.tag) as tags,
|
|
1 - (c.embedding <=> %s::vector) as similarity
|
|
FROM {self.table_prefix}document_chunks c
|
|
JOIN {self.table_prefix}documents d ON c.document_id = d.document_id
|
|
LEFT JOIN {self.table_prefix}document_tags dt ON d.document_id = dt.document_id
|
|
WHERE d.user_id = %s
|
|
GROUP BY c.chunk_id, c.document_id, c.content, c.chunk_index, c.metadata,
|
|
d.user_id, d.site_id, d.file_path, d.filename, d.created_at,
|
|
d.updated_at, d.chunk_count, d.content_hash, d.metadata, c.embedding
|
|
ORDER BY similarity DESC
|
|
LIMIT %s
|
|
""", [query_embedding, user_id, top_k])
|
|
|
|
rows = cur.fetchall()
|
|
|
|
# Convert to result format
|
|
results = []
|
|
for row in rows:
|
|
chunk = DocumentChunk(
|
|
chunk_id=row['chunk_id'],
|
|
document_id=row['document_id'],
|
|
content=row['content'],
|
|
chunk_index=row['chunk_index'],
|
|
metadata=row['chunk_metadata']
|
|
)
|
|
|
|
document = Document(
|
|
document_id=row['document_id'],
|
|
user_id=row['user_id'],
|
|
site_id=row['site_id'],
|
|
file_path=row['file_path'],
|
|
filename=row['filename'],
|
|
tags=row['tags'] or [],
|
|
content_hash=row['content_hash'],
|
|
created_at=row['created_at'].isoformat(),
|
|
updated_at=row['updated_at'].isoformat(),
|
|
chunk_count=row['chunk_count'],
|
|
metadata=row['doc_metadata']
|
|
)
|
|
|
|
results.append({
|
|
"chunk": chunk,
|
|
"document": document,
|
|
"similarity": float(row['similarity'])
|
|
})
|
|
|
|
return results
|
|
|
|
def get_documents_by_tags(self, user_id: str, tags: List[str]) -> List[Document]:
|
|
"""Get all documents with specific tags."""
|
|
with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
|
|
tag_placeholders = ','.join(['%s'] * len(tags))
|
|
cur.execute(f"""
|
|
SELECT DISTINCT
|
|
d.document_id, d.user_id, d.site_id, d.file_path, d.filename,
|
|
d.content_hash, d.created_at, d.updated_at, d.chunk_count, d.metadata,
|
|
ARRAY_AGG(dt.tag) as tags
|
|
FROM {self.table_prefix}documents d
|
|
JOIN {self.table_prefix}document_tags dt ON d.document_id = dt.document_id
|
|
WHERE d.user_id = %s AND LOWER(dt.tag) IN ({tag_placeholders})
|
|
GROUP BY d.document_id, d.user_id, d.site_id, d.file_path, d.filename,
|
|
d.content_hash, d.created_at, d.updated_at, d.chunk_count, d.metadata
|
|
""", [user_id] + [tag.lower() for tag in tags])
|
|
|
|
rows = cur.fetchall()
|
|
|
|
return [
|
|
Document(
|
|
document_id=row['document_id'],
|
|
user_id=row['user_id'],
|
|
site_id=row['site_id'],
|
|
file_path=row['file_path'],
|
|
filename=row['filename'],
|
|
tags=row['tags'] or [],
|
|
content_hash=row['content_hash'],
|
|
created_at=row['created_at'].isoformat(),
|
|
updated_at=row['updated_at'].isoformat(),
|
|
chunk_count=row['chunk_count'],
|
|
metadata=row['metadata']
|
|
)
|
|
for row in rows
|
|
]
|
|
|
|
def list_tags(self, user_id: str) -> Dict[str, int]:
|
|
"""List all tags for user with document counts."""
|
|
with self.conn.cursor() as cur:
|
|
cur.execute(f"""
|
|
SELECT dt.tag, COUNT(DISTINCT dt.document_id) as count
|
|
FROM {self.table_prefix}document_tags dt
|
|
JOIN {self.table_prefix}documents d ON dt.document_id = d.document_id
|
|
WHERE d.user_id = %s
|
|
GROUP BY dt.tag
|
|
ORDER BY count DESC, dt.tag
|
|
""", (user_id,))
|
|
|
|
return {row[0]: row[1] for row in cur.fetchall()}
|
|
|
|
def get_indexed_sites(self, user_id: str) -> List[str]:
|
|
"""
|
|
Get list of site IDs that have indexed documents for this user.
|
|
|
|
Args:
|
|
user_id: User ID
|
|
|
|
Returns:
|
|
List of site IDs with indexed documents
|
|
"""
|
|
with self.conn.cursor() as cur:
|
|
cur.execute(f"""
|
|
SELECT DISTINCT site_id
|
|
FROM {self.table_prefix}documents
|
|
WHERE user_id = %s
|
|
ORDER BY site_id
|
|
""", (user_id,))
|
|
|
|
return [row[0] for row in cur.fetchall()]
|
|
|
|
def update_document_tags(self, document_id: str, user_id: str, tags: List[str]):
|
|
"""Update tags for a document."""
|
|
with self.conn.cursor() as cur:
|
|
# Verify ownership
|
|
cur.execute(
|
|
f"SELECT user_id FROM {self.table_prefix}documents WHERE document_id = %s",
|
|
(document_id,)
|
|
)
|
|
row = cur.fetchone()
|
|
|
|
if not row:
|
|
raise ValueError("Document not found")
|
|
|
|
if row[0] != user_id:
|
|
raise ValueError("Access denied")
|
|
|
|
# Update tags
|
|
self._update_tags(cur, document_id, tags)
|
|
|
|
# Update timestamp
|
|
cur.execute(
|
|
f"UPDATE {self.table_prefix}documents SET updated_at = %s WHERE document_id = %s",
|
|
(datetime.utcnow(), document_id)
|
|
)
|
|
|
|
self.conn.commit()
|
|
|
|
def remove_document(self, document_id: str, user_id: str):
|
|
"""Remove a document and its chunks."""
|
|
with self.conn.cursor() as cur:
|
|
# Verify ownership
|
|
cur.execute(
|
|
f"SELECT user_id FROM {self.table_prefix}documents WHERE document_id = %s",
|
|
(document_id,)
|
|
)
|
|
row = cur.fetchone()
|
|
|
|
if not row:
|
|
return
|
|
|
|
if row[0] != user_id:
|
|
raise ValueError("Access denied")
|
|
|
|
# Delete document (cascades to chunks and tags)
|
|
cur.execute(
|
|
f"DELETE FROM {self.table_prefix}documents WHERE document_id = %s",
|
|
(document_id,)
|
|
)
|
|
|
|
self.conn.commit()
|
|
|
|
def _update_tags(self, cur, document_id: str, tags: List[str]):
|
|
"""Update tags for a document within a transaction."""
|
|
# Delete old tags
|
|
cur.execute(
|
|
f"DELETE FROM {self.table_prefix}document_tags WHERE document_id = %s",
|
|
(document_id,)
|
|
)
|
|
|
|
# Insert new tags
|
|
if tags:
|
|
tag_data = [(document_id, tag.lower()) for tag in tags]
|
|
execute_values(
|
|
cur,
|
|
f"INSERT INTO {self.table_prefix}document_tags (document_id, tag) VALUES %s",
|
|
tag_data
|
|
)
|
|
|
|
def _chunk_text(self, text: str, chunk_size: int, overlap: int) -> List[str]:
|
|
"""Split text into overlapping chunks."""
|
|
if not text:
|
|
return []
|
|
|
|
chunks = []
|
|
start = 0
|
|
text_length = len(text)
|
|
|
|
while start < text_length:
|
|
end = start + chunk_size
|
|
|
|
# Try to break at sentence boundary
|
|
if end < text_length:
|
|
for punct in ['. ', '! ', '? ', '\n\n']:
|
|
last_punct = text.rfind(punct, start, end)
|
|
if last_punct != -1:
|
|
end = last_punct + len(punct)
|
|
break
|
|
|
|
chunk = text[start:end].strip()
|
|
if chunk:
|
|
chunks.append(chunk)
|
|
|
|
start = end - overlap if end < text_length else text_length
|
|
|
|
return chunks
|
|
|
|
def _generate_document_id(self, user_id: str, site_id: str, file_path: str) -> str:
|
|
"""Generate unique document ID."""
|
|
combined = f"{user_id}:{site_id}:{file_path}"
|
|
return hashlib.sha256(combined.encode()).hexdigest()
|
|
|
|
def close(self):
|
|
"""Close database connection."""
|
|
if self.conn:
|
|
self.conn.close()
|
|
|
|
def __enter__(self):
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
self.close()
|
|
|
|
|
|
def create_embedding_provider(provider: str = "ollama", **kwargs) -> EmbeddingProvider:
|
|
"""
|
|
Factory function to create embedding provider.
|
|
|
|
Args:
|
|
provider: Provider name ("ollama" or "openai")
|
|
**kwargs: Provider-specific configuration
|
|
|
|
Returns:
|
|
EmbeddingProvider instance
|
|
"""
|
|
if provider.lower() == "ollama":
|
|
return OllamaEmbeddings(
|
|
base_url=kwargs.get("base_url", os.getenv("OLLAMA_URL", "http://localhost:11434")),
|
|
model=kwargs.get("model", os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text"))
|
|
)
|
|
elif provider.lower() == "openai":
|
|
api_key = kwargs.get("api_key", os.getenv("OPENAI_API_KEY"))
|
|
if not api_key:
|
|
raise ValueError("OpenAI API key required")
|
|
return OpenAIEmbeddings(
|
|
api_key=api_key,
|
|
model=kwargs.get("model", "text-embedding-3-small")
|
|
)
|
|
else:
|
|
raise ValueError(f"Unknown embedding provider: {provider}")
|