# tf_sharepoint_integration/vector_store_postgres.py
# Commit bcd0f8a227 — Daniel Grozdanovic, 2026-02-22 17:58:45 +02:00
# "Initial commit: SharePoint connector and ToothFairyAI integration"
# CI: SharePoint Plugin with SonarQube / Test and SonarQube Analysis (push) — cancelled

"""
PostgreSQL Vector Store with pgvector for Multi-Document Chat with RAG
Persistent vector storage using PostgreSQL with pgvector extension.
"""
import hashlib
import os
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import psycopg2
from pgvector.psycopg2 import register_vector
from psycopg2.extras import RealDictCursor, execute_values
@dataclass
class DocumentChunk:
"""A chunk of a document with metadata."""
chunk_id: str
document_id: str
content: str
chunk_index: int
embedding: Optional[List[float]] = None
metadata: Optional[Dict[str, Any]] = None
@dataclass
class Document:
"""Document metadata with tags."""
document_id: str
user_id: str
site_id: str
file_path: str
filename: str
tags: List[str]
content_hash: str
created_at: str
updated_at: str
chunk_count: int
metadata: Optional[Dict[str, Any]] = None
class EmbeddingProvider:
"""Base class for embedding providers."""
def embed_text(self, text: str) -> List[float]:
"""Generate embedding for text."""
raise NotImplementedError
def embed_batch(self, texts: List[str]) -> List[List[float]]:
"""Generate embeddings for multiple texts."""
return [self.embed_text(text) for text in texts]
class OllamaEmbeddings(EmbeddingProvider):
"""Ollama embeddings using local models."""
def __init__(self, base_url: str = "http://localhost:11434", model: str = "nomic-embed-text"):
"""
Initialize Ollama embeddings.
Args:
base_url: Ollama server URL
model: Embedding model (e.g., nomic-embed-text, mxbai-embed-large)
"""
import requests
self.base_url = base_url.rstrip('/')
self.model = model
self.requests = requests
def embed_text(self, text: str) -> List[float]:
"""Generate embedding using Ollama."""
response = self.requests.post(
f"{self.base_url}/api/embeddings",
json={
"model": self.model,
"prompt": text
}
)
response.raise_for_status()
return response.json()["embedding"]
class OpenAIEmbeddings(EmbeddingProvider):
"""OpenAI embeddings."""
def __init__(self, api_key: str, model: str = "text-embedding-3-small"):
"""
Initialize OpenAI embeddings.
Args:
api_key: OpenAI API key
model: Embedding model (text-embedding-3-small or text-embedding-3-large)
"""
import requests
self.api_key = api_key
self.model = model
self.requests = requests
def embed_text(self, text: str) -> List[float]:
"""Generate embedding using OpenAI."""
response = self.requests.post(
"https://api.openai.com/v1/embeddings",
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
},
json={
"input": text,
"model": self.model
}
)
response.raise_for_status()
return response.json()["data"][0]["embedding"]
class PostgreSQLVectorStore:
    """
    PostgreSQL vector store with pgvector extension.

    Provides persistent storage for document embeddings with tag-based
    organization.  Table names are built from ``table_prefix``; the prefix is
    deployment configuration (never user input), which is why it is
    interpolated into SQL strings directly while every runtime value goes
    through query parameters.

    The connection runs with ``autocommit = False``; every write path commits
    on success and rolls back on failure so a failed statement does not leave
    the connection stuck in an aborted transaction.
    """

    def __init__(
        self,
        embedding_provider: EmbeddingProvider,
        connection_string: Optional[str] = None,
        table_prefix: str = ""
    ):
        """
        Initialize PostgreSQL vector store.

        Args:
            embedding_provider: Provider for generating embeddings
            connection_string: PostgreSQL connection string (uses env vars if not provided)
            table_prefix: Prefix for table names (e.g., "dev_" or "prod_")
        """
        self.embedding_provider = embedding_provider
        self.table_prefix = table_prefix
        # Use the explicit connection string, or build one from env vars.
        if connection_string:
            self.connection_string = connection_string
        else:
            self.connection_string = (
                f"host={os.getenv('POSTGRES_HOST', 'localhost')} "
                f"port={os.getenv('POSTGRES_PORT', '5432')} "
                f"dbname={os.getenv('POSTGRES_DB', 'sharepoint_vectors')} "
                f"user={os.getenv('POSTGRES_USER', 'postgres')} "
                f"password={os.getenv('POSTGRES_PASSWORD', 'postgres')}"
            )
        # Initialize database connection with explicit transactions.
        self.conn = psycopg2.connect(self.connection_string)
        self.conn.autocommit = False
        # Create tables and enable extension if they don't exist.
        self._create_tables()
        # Register pgvector type adapters (must be after extension is enabled).
        register_vector(self.conn)

    @staticmethod
    def _utcnow() -> datetime:
        """Naive UTC timestamp matching the tz-less TIMESTAMP columns.

        Replaces the deprecated ``datetime.utcnow()`` with the same value.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _create_tables(self):
        """Create tables for documents, chunks, and tags.

        Rolls back and re-raises on failure so the connection stays usable.
        """
        try:
            with self.conn.cursor() as cur:
                # Enable pgvector extension
                cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
                # Documents table
                cur.execute(f"""
                    CREATE TABLE IF NOT EXISTS {self.table_prefix}documents (
                        document_id VARCHAR(64) PRIMARY KEY,
                        user_id VARCHAR(255) NOT NULL,
                        site_id VARCHAR(255) NOT NULL,
                        file_path TEXT NOT NULL,
                        filename VARCHAR(255) NOT NULL,
                        content_hash VARCHAR(64) NOT NULL,
                        created_at TIMESTAMP NOT NULL,
                        updated_at TIMESTAMP NOT NULL,
                        chunk_count INTEGER NOT NULL,
                        metadata JSONB,
                        UNIQUE(user_id, site_id, file_path)
                    )
                """)
                # Document chunks table with vector embeddings.
                # Assuming 768 dimensions for nomic-embed-text (adjust based on your model)
                cur.execute(f"""
                    CREATE TABLE IF NOT EXISTS {self.table_prefix}document_chunks (
                        chunk_id VARCHAR(128) PRIMARY KEY,
                        document_id VARCHAR(64) NOT NULL REFERENCES {self.table_prefix}documents(document_id) ON DELETE CASCADE,
                        content TEXT NOT NULL,
                        chunk_index INTEGER NOT NULL,
                        embedding vector(768),
                        metadata JSONB,
                        UNIQUE(document_id, chunk_index)
                    )
                """)
                # Tags table (many-to-many with documents)
                cur.execute(f"""
                    CREATE TABLE IF NOT EXISTS {self.table_prefix}document_tags (
                        document_id VARCHAR(64) NOT NULL REFERENCES {self.table_prefix}documents(document_id) ON DELETE CASCADE,
                        tag VARCHAR(100) NOT NULL,
                        PRIMARY KEY(document_id, tag)
                    )
                """)
                # Create indexes for performance
                cur.execute(f"""
                    CREATE INDEX IF NOT EXISTS idx_documents_user_id
                    ON {self.table_prefix}documents(user_id)
                """)
                cur.execute(f"""
                    CREATE INDEX IF NOT EXISTS idx_document_tags_tag
                    ON {self.table_prefix}document_tags(tag)
                """)
                # Create IVFFlat index for vector similarity search (after some data is inserted)
                # This is commented out - create it manually after inserting ~1000+ vectors:
                # CREATE INDEX ON document_chunks USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);
            self.conn.commit()
        except Exception:
            self.conn.rollback()
            raise

    def add_document(
        self,
        user_id: str,
        site_id: str,
        file_path: str,
        filename: str,
        content: str,
        tags: List[str],
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Add a document to the vector store with chunking and embeddings.

        If the document already exists with the same content hash, only the
        tags and the updated_at timestamp are refreshed (no re-embedding).

        Args:
            user_id: User ID
            site_id: SharePoint site ID
            file_path: File path
            filename: Filename
            content: Document content
            tags: List of tags (e.g., ["HR", "SALES", "Q4-2024"])
            chunk_size: Size of each chunk in characters
            chunk_overlap: Overlap between chunks
            metadata: Additional metadata

        Returns:
            document_id
        """
        # Deterministic id so the same (user, site, path) always maps to one row.
        document_id = self._generate_document_id(user_id, site_id, file_path)
        content_hash = hashlib.sha256(content.encode()).hexdigest()
        try:
            with self.conn.cursor() as cur:
                # Check if document already exists
                cur.execute(
                    f"SELECT content_hash FROM {self.table_prefix}documents WHERE document_id = %s",
                    (document_id,)
                )
                existing = cur.fetchone()
                if existing and existing[0] == content_hash:
                    # Content hasn't changed, just update tags if different
                    cur.execute(
                        f"SELECT tag FROM {self.table_prefix}document_tags WHERE document_id = %s",
                        (document_id,)
                    )
                    existing_tags = [row[0] for row in cur.fetchall()]
                    if set(tags) != set(existing_tags):
                        self._update_tags(cur, document_id, tags)
                        cur.execute(
                            f"UPDATE {self.table_prefix}documents SET updated_at = %s WHERE document_id = %s",
                            (self._utcnow(), document_id)
                        )
                    self.conn.commit()
                    return document_id
                # Delete old chunks if content changed (tags cascade-safe: the
                # document row itself is kept and updated below).
                if existing:
                    cur.execute(
                        f"DELETE FROM {self.table_prefix}document_chunks WHERE document_id = %s",
                        (document_id,)
                    )
                # Chunk the document and embed every chunk up front.
                chunks = self._chunk_text(content, chunk_size, chunk_overlap)
                embeddings = self.embedding_provider.embed_batch(chunks)
                # Insert or update the document row.
                now = self._utcnow()
                if existing:
                    cur.execute(f"""
                        UPDATE {self.table_prefix}documents
                        SET content_hash = %s, updated_at = %s, chunk_count = %s, metadata = %s
                        WHERE document_id = %s
                    """, (content_hash, now, len(chunks), psycopg2.extras.Json(metadata), document_id))
                else:
                    cur.execute(f"""
                        INSERT INTO {self.table_prefix}documents
                        (document_id, user_id, site_id, file_path, filename, content_hash, created_at, updated_at, chunk_count, metadata)
                        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                    """, (document_id, user_id, site_id, file_path, filename, content_hash, now, now, len(chunks), psycopg2.extras.Json(metadata)))
                # Bulk-insert chunks with embeddings.
                chunk_data = [
                    (
                        f"{document_id}_chunk_{idx}",
                        document_id,
                        chunk_text,
                        idx,
                        embedding,
                        psycopg2.extras.Json(metadata)
                    )
                    for idx, (chunk_text, embedding) in enumerate(zip(chunks, embeddings))
                ]
                execute_values(
                    cur,
                    f"""
                    INSERT INTO {self.table_prefix}document_chunks
                    (chunk_id, document_id, content, chunk_index, embedding, metadata)
                    VALUES %s
                    """,
                    chunk_data
                )
                # Update tags
                self._update_tags(cur, document_id, tags)
            self.conn.commit()
            return document_id
        except Exception:
            # Leave the connection usable for the next call.
            self.conn.rollback()
            raise

    def search(
        self,
        user_id: str,
        query: str,
        tags: Optional[List[str]] = None,
        top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Search for relevant document chunks using vector similarity.

        Args:
            user_id: User ID (for access control)
            query: Search query
            tags: Optional list of tags to filter by (case-insensitive)
            top_k: Number of results to return

        Returns:
            List of dicts with keys "chunk" (DocumentChunk), "document"
            (Document) and "similarity" (cosine similarity in [0, 1]).
            NOTE(review): when `tags` is given, the aggregated tag list on
            each result contains only the matching tags, not all of the
            document's tags.
        """
        # Generate query embedding
        query_embedding = self.embedding_provider.embed_text(query)
        with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
            # Build query based on whether tags are specified
            if tags:
                # Filter by tags using JOIN
                tag_placeholders = ','.join(['%s'] * len(tags))
                cur.execute(f"""
                    SELECT
                        c.chunk_id, c.document_id, c.content, c.chunk_index, c.metadata as chunk_metadata,
                        d.user_id, d.site_id, d.file_path, d.filename, d.created_at, d.updated_at,
                        d.chunk_count, d.content_hash, d.metadata as doc_metadata,
                        ARRAY_AGG(dt.tag) as tags,
                        1 - (c.embedding <=> %s::vector) as similarity
                    FROM {self.table_prefix}document_chunks c
                    JOIN {self.table_prefix}documents d ON c.document_id = d.document_id
                    JOIN {self.table_prefix}document_tags dt ON d.document_id = dt.document_id
                    WHERE d.user_id = %s AND LOWER(dt.tag) IN ({tag_placeholders})
                    GROUP BY c.chunk_id, c.document_id, c.content, c.chunk_index, c.metadata,
                             d.user_id, d.site_id, d.file_path, d.filename, d.created_at,
                             d.updated_at, d.chunk_count, d.content_hash, d.metadata, c.embedding
                    ORDER BY similarity DESC
                    LIMIT %s
                """, [query_embedding, user_id] + [tag.lower() for tag in tags] + [top_k])
            else:
                # All user's documents
                cur.execute(f"""
                    SELECT
                        c.chunk_id, c.document_id, c.content, c.chunk_index, c.metadata as chunk_metadata,
                        d.user_id, d.site_id, d.file_path, d.filename, d.created_at, d.updated_at,
                        d.chunk_count, d.content_hash, d.metadata as doc_metadata,
                        ARRAY_AGG(dt.tag) as tags,
                        1 - (c.embedding <=> %s::vector) as similarity
                    FROM {self.table_prefix}document_chunks c
                    JOIN {self.table_prefix}documents d ON c.document_id = d.document_id
                    LEFT JOIN {self.table_prefix}document_tags dt ON d.document_id = dt.document_id
                    WHERE d.user_id = %s
                    GROUP BY c.chunk_id, c.document_id, c.content, c.chunk_index, c.metadata,
                             d.user_id, d.site_id, d.file_path, d.filename, d.created_at,
                             d.updated_at, d.chunk_count, d.content_hash, d.metadata, c.embedding
                    ORDER BY similarity DESC
                    LIMIT %s
                """, [query_embedding, user_id, top_k])
            rows = cur.fetchall()
        # Convert DB rows to result objects.
        results = []
        for row in rows:
            chunk = DocumentChunk(
                chunk_id=row['chunk_id'],
                document_id=row['document_id'],
                content=row['content'],
                chunk_index=row['chunk_index'],
                metadata=row['chunk_metadata']
            )
            document = Document(
                document_id=row['document_id'],
                user_id=row['user_id'],
                site_id=row['site_id'],
                file_path=row['file_path'],
                filename=row['filename'],
                tags=row['tags'] or [],
                content_hash=row['content_hash'],
                created_at=row['created_at'].isoformat(),
                updated_at=row['updated_at'].isoformat(),
                chunk_count=row['chunk_count'],
                metadata=row['doc_metadata']
            )
            results.append({
                "chunk": chunk,
                "document": document,
                "similarity": float(row['similarity'])
            })
        return results

    def get_documents_by_tags(self, user_id: str, tags: List[str]) -> List[Document]:
        """Get all documents with specific tags (case-insensitive match)."""
        with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
            tag_placeholders = ','.join(['%s'] * len(tags))
            cur.execute(f"""
                SELECT DISTINCT
                    d.document_id, d.user_id, d.site_id, d.file_path, d.filename,
                    d.content_hash, d.created_at, d.updated_at, d.chunk_count, d.metadata,
                    ARRAY_AGG(dt.tag) as tags
                FROM {self.table_prefix}documents d
                JOIN {self.table_prefix}document_tags dt ON d.document_id = dt.document_id
                WHERE d.user_id = %s AND LOWER(dt.tag) IN ({tag_placeholders})
                GROUP BY d.document_id, d.user_id, d.site_id, d.file_path, d.filename,
                         d.content_hash, d.created_at, d.updated_at, d.chunk_count, d.metadata
            """, [user_id] + [tag.lower() for tag in tags])
            rows = cur.fetchall()
        return [
            Document(
                document_id=row['document_id'],
                user_id=row['user_id'],
                site_id=row['site_id'],
                file_path=row['file_path'],
                filename=row['filename'],
                tags=row['tags'] or [],
                content_hash=row['content_hash'],
                created_at=row['created_at'].isoformat(),
                updated_at=row['updated_at'].isoformat(),
                chunk_count=row['chunk_count'],
                metadata=row['metadata']
            )
            for row in rows
        ]

    def list_tags(self, user_id: str) -> Dict[str, int]:
        """List all tags for user with document counts (most used first)."""
        with self.conn.cursor() as cur:
            cur.execute(f"""
                SELECT dt.tag, COUNT(DISTINCT dt.document_id) as count
                FROM {self.table_prefix}document_tags dt
                JOIN {self.table_prefix}documents d ON dt.document_id = d.document_id
                WHERE d.user_id = %s
                GROUP BY dt.tag
                ORDER BY count DESC, dt.tag
            """, (user_id,))
            return {row[0]: row[1] for row in cur.fetchall()}

    def get_indexed_sites(self, user_id: str) -> List[str]:
        """
        Get list of site IDs that have indexed documents for this user.

        Args:
            user_id: User ID

        Returns:
            Sorted list of site IDs with indexed documents
        """
        with self.conn.cursor() as cur:
            cur.execute(f"""
                SELECT DISTINCT site_id
                FROM {self.table_prefix}documents
                WHERE user_id = %s
                ORDER BY site_id
            """, (user_id,))
            return [row[0] for row in cur.fetchall()]

    def update_document_tags(self, document_id: str, user_id: str, tags: List[str]):
        """Update tags for a document.

        Raises:
            ValueError: if the document does not exist or belongs to another user.
        """
        try:
            with self.conn.cursor() as cur:
                # Verify ownership before touching anything.
                cur.execute(
                    f"SELECT user_id FROM {self.table_prefix}documents WHERE document_id = %s",
                    (document_id,)
                )
                row = cur.fetchone()
                if not row:
                    raise ValueError("Document not found")
                if row[0] != user_id:
                    raise ValueError("Access denied")
                # Update tags
                self._update_tags(cur, document_id, tags)
                # Update timestamp
                cur.execute(
                    f"UPDATE {self.table_prefix}documents SET updated_at = %s WHERE document_id = %s",
                    (self._utcnow(), document_id)
                )
            self.conn.commit()
        except Exception:
            self.conn.rollback()
            raise

    def remove_document(self, document_id: str, user_id: str):
        """Remove a document and its chunks/tags (via ON DELETE CASCADE).

        Silently returns if the document does not exist.

        Raises:
            ValueError: if the document belongs to another user.
        """
        try:
            with self.conn.cursor() as cur:
                # Verify ownership
                cur.execute(
                    f"SELECT user_id FROM {self.table_prefix}documents WHERE document_id = %s",
                    (document_id,)
                )
                row = cur.fetchone()
                if not row:
                    return
                if row[0] != user_id:
                    raise ValueError("Access denied")
                # Delete document (cascades to chunks and tags)
                cur.execute(
                    f"DELETE FROM {self.table_prefix}documents WHERE document_id = %s",
                    (document_id,)
                )
            self.conn.commit()
        except Exception:
            self.conn.rollback()
            raise

    def _update_tags(self, cur, document_id: str, tags: List[str]):
        """Replace a document's tags within the caller's open transaction.

        Tags are stored lower-cased; the caller is responsible for commit.
        """
        # Delete old tags
        cur.execute(
            f"DELETE FROM {self.table_prefix}document_tags WHERE document_id = %s",
            (document_id,)
        )
        # Insert new tags
        if tags:
            tag_data = [(document_id, tag.lower()) for tag in tags]
            execute_values(
                cur,
                f"INSERT INTO {self.table_prefix}document_tags (document_id, tag) VALUES %s",
                tag_data
            )

    def _chunk_text(self, text: str, chunk_size: int, overlap: int) -> List[str]:
        """Split text into overlapping chunks, preferring sentence boundaries.

        Fix: the original advanced with ``start = end - overlap`` even when a
        sentence break pulled ``end`` close to (or before) ``start + overlap``,
        which could move the cursor backwards and loop forever (e.g. a break
        right after ``start`` with ``overlap >= end - start``).  The cursor now
        always makes forward progress.
        """
        if not text:
            return []
        chunks = []
        start = 0
        text_length = len(text)
        while start < text_length:
            end = start + chunk_size
            # Try to break at a sentence boundary inside the window.
            if end < text_length:
                for punct in ['. ', '! ', '? ', '\n\n']:
                    last_punct = text.rfind(punct, start, end)
                    if last_punct != -1:
                        end = last_punct + len(punct)
                        break
            chunk = text[start:end].strip()
            if chunk:
                chunks.append(chunk)
            if end >= text_length:
                break
            # Step back by `overlap`, but never stall or move backwards.
            next_start = end - overlap
            start = next_start if next_start > start else end
        return chunks

    def _generate_document_id(self, user_id: str, site_id: str, file_path: str) -> str:
        """Generate a deterministic unique document ID (SHA-256 hex digest)."""
        combined = f"{user_id}:{site_id}:{file_path}"
        return hashlib.sha256(combined.encode()).hexdigest()

    def close(self):
        """Close the database connection."""
        if self.conn:
            self.conn.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
def create_embedding_provider(provider: str = "ollama", **kwargs) -> EmbeddingProvider:
    """
    Factory function to create embedding provider.

    Args:
        provider: Provider name ("ollama" or "openai"), case-insensitive
        **kwargs: Provider-specific configuration (base_url/model/api_key);
            falls back to OLLAMA_URL, OLLAMA_EMBED_MODEL and OPENAI_API_KEY
            environment variables when omitted

    Returns:
        EmbeddingProvider instance

    Raises:
        ValueError: for an unknown provider name, or a missing OpenAI API key
    """
    name = provider.lower()
    if name == "ollama":
        base_url = kwargs.get("base_url", os.getenv("OLLAMA_URL", "http://localhost:11434"))
        model = kwargs.get("model", os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text"))
        return OllamaEmbeddings(base_url=base_url, model=model)
    if name == "openai":
        api_key = kwargs.get("api_key", os.getenv("OPENAI_API_KEY"))
        if not api_key:
            raise ValueError("OpenAI API key required")
        model = kwargs.get("model", "text-embedding-3-small")
        return OpenAIEmbeddings(api_key=api_key, model=model)
    raise ValueError(f"Unknown embedding provider: {provider}")