"""
PostgreSQL Vector Store with pgvector for Multi-Document Chat with RAG.

Persistent vector storage using PostgreSQL with the pgvector extension.
"""

import hashlib
import os
import re
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import psycopg2
from pgvector.psycopg2 import register_vector
from psycopg2.extensions import make_dsn
from psycopg2.extras import Json, RealDictCursor, execute_values


# Default size of the ``embedding`` column. 768 matches nomic-embed-text (the
# default Ollama model); OpenAI's text-embedding-3-small produces 1536.
DEFAULT_EMBEDDING_DIM = 768

# Default per-request HTTP timeout (seconds) for embedding calls.
# Pass ``timeout=None`` to a provider to wait indefinitely.
DEFAULT_HTTP_TIMEOUT = 120.0


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    The timestamp columns are ``TIMESTAMP`` (without time zone), so values are
    stored as naive UTC. This replaces the deprecated ``datetime.utcnow()``.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


@dataclass
class DocumentChunk:
    """A chunk of a document with metadata."""
    chunk_id: str
    document_id: str
    content: str
    chunk_index: int
    embedding: Optional[List[float]] = None
    metadata: Optional[Dict[str, Any]] = None


@dataclass
class Document:
    """Document metadata with tags."""
    document_id: str
    user_id: str
    site_id: str
    file_path: str
    filename: str
    tags: List[str]
    content_hash: str
    created_at: str  # ISO-8601 string of the stored naive-UTC timestamp
    updated_at: str  # ISO-8601 string of the stored naive-UTC timestamp
    chunk_count: int
    metadata: Optional[Dict[str, Any]] = None


class EmbeddingProvider:
    """Base class for embedding providers."""

    def embed_text(self, text: str) -> List[float]:
        """Generate embedding for text."""
        raise NotImplementedError

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple texts, in input order."""
        return [self.embed_text(text) for text in texts]


class OllamaEmbeddings(EmbeddingProvider):
    """Ollama embeddings using local models."""

    def __init__(
        self,
        base_url: str = "http://localhost:11434",
        model: str = "nomic-embed-text",
        timeout: Optional[float] = DEFAULT_HTTP_TIMEOUT,
    ):
        """
        Initialize Ollama embeddings.

        Args:
            base_url: Ollama server URL
            model: Embedding model (e.g., nomic-embed-text, mxbai-embed-large)
            timeout: Per-request timeout in seconds; None waits indefinitely
        """
        import requests

        self.base_url = base_url.rstrip('/')
        self.model = model
        self.timeout = timeout
        self.requests = requests

    def embed_text(self, text: str) -> List[float]:
        """Generate an embedding via Ollama's ``/api/embeddings`` endpoint."""
        response = self.requests.post(
            f"{self.base_url}/api/embeddings",
            json={
                "model": self.model,
                "prompt": text,
            },
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()["embedding"]


class OpenAIEmbeddings(EmbeddingProvider):
    """OpenAI embeddings."""

    API_URL = "https://api.openai.com/v1/embeddings"

    def __init__(
        self,
        api_key: str,
        model: str = "text-embedding-3-small",
        timeout: Optional[float] = DEFAULT_HTTP_TIMEOUT,
        batch_size: int = 100,
    ):
        """
        Initialize OpenAI embeddings.

        Args:
            api_key: OpenAI API key
            model: Embedding model (text-embedding-3-small or text-embedding-3-large)
            timeout: Per-request timeout in seconds; None waits indefinitely
            batch_size: Maximum number of texts sent per request by embed_batch

        Raises:
            ValueError: If batch_size is less than 1
        """
        import requests

        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        self.api_key = api_key
        self.model = model
        self.timeout = timeout
        self.batch_size = batch_size
        self.requests = requests

    def _request(self, payload_input: Any) -> List[Dict[str, Any]]:
        """POST ``payload_input`` (a string or list of strings) and return ``data``."""
        response = self.requests.post(
            self.API_URL,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json={
                "input": payload_input,
                "model": self.model,
            },
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()["data"]

    def embed_text(self, text: str) -> List[float]:
        """Generate embedding using OpenAI."""
        return self._request(text)[0]["embedding"]

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed many texts with one request per ``batch_size`` inputs.

        Results are returned in input order.
        """
        embeddings: List[List[float]] = []
        for offset in range(0, len(texts), self.batch_size):
            data = self._request(texts[offset:offset + self.batch_size])
            # Each result carries the position of its input; sort so the output
            # order never depends on the response order.
            data.sort(key=lambda item: item["index"])
            embeddings.extend(item["embedding"] for item in data)
        return embeddings


class PostgreSQLVectorStore:
    """
    PostgreSQL vector store with pgvector extension.

    Provides persistent storage for document embeddings with tag-based organization.
    Every public method runs in its own transaction (``with self.conn``). The
    transaction is committed on success and rolled back on any exception, so a
    failure never leaves the shared connection in an aborted state.
    """

    def __init__(
        self,
        embedding_provider: EmbeddingProvider,
        connection_string: Optional[str] = None,
        table_prefix: str = "",
        embedding_dim: int = DEFAULT_EMBEDDING_DIM,
    ):
        """
        Initialize PostgreSQL vector store.

        Args:
            embedding_provider: Provider for generating embeddings
            connection_string: PostgreSQL connection string (uses env vars if not provided)
            table_prefix: Prefix for table names (e.g., "dev_" or "prod_")
            embedding_dim: Dimensionality of the ``embedding`` column. It must
                match the provider's output size. It only takes effect when the
                chunks table is first created; an existing table keeps its type.

        Raises:
            ValueError: If embedding_dim is not a positive integer
        """
        embedding_dim = int(embedding_dim)  # interpolated into DDL; force an int
        if embedding_dim <= 0:
            raise ValueError("embedding_dim must be a positive integer")

        self.embedding_provider = embedding_provider
        self.table_prefix = table_prefix
        self.embedding_dim = embedding_dim
        self.conn = None

        # Use connection string or build from env vars
        if connection_string:
            self.connection_string = connection_string
        else:
            # make_dsn quotes/escapes values, so e.g. passwords with spaces work.
            self.connection_string = make_dsn(
                host=os.getenv('POSTGRES_HOST', 'localhost'),
                port=os.getenv('POSTGRES_PORT', '5432'),
                dbname=os.getenv('POSTGRES_DB', 'sharepoint_vectors'),
                user=os.getenv('POSTGRES_USER', 'postgres'),
                password=os.getenv('POSTGRES_PASSWORD', 'postgres'),
            )

        # Initialize database connection
        self.conn = psycopg2.connect(self.connection_string)
        try:
            self.conn.autocommit = False
            # Create tables and enable extension if they don't exist
            self._create_tables()
            # Register pgvector type (must be after extension is enabled)
            register_vector(self.conn)
        except Exception:
            # Don't leak the connection when setup fails.
            self.conn.close()
            raise

    def _create_tables(self):
        """Create the documents, chunks and tags tables (and indexes) if missing."""
        prefix = self.table_prefix
        # Index names live in the schema namespace, not per table. They must
        # carry the prefix too; otherwise a second prefixed table set's
        # ``CREATE INDEX IF NOT EXISTS`` silently finds the first set's index
        # and skips creation. Non-word characters are mapped to "_" so the
        # result is a valid identifier.
        index_prefix = re.sub(r"\W", "_", prefix)

        with self.conn, self.conn.cursor() as cur:
            # Enable pgvector extension
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")

            # Documents table
            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS {prefix}documents (
                    document_id VARCHAR(64) PRIMARY KEY,
                    user_id VARCHAR(255) NOT NULL,
                    site_id VARCHAR(255) NOT NULL,
                    file_path TEXT NOT NULL,
                    filename VARCHAR(255) NOT NULL,
                    content_hash VARCHAR(64) NOT NULL,
                    created_at TIMESTAMP NOT NULL,
                    updated_at TIMESTAMP NOT NULL,
                    chunk_count INTEGER NOT NULL,
                    metadata JSONB,
                    UNIQUE(user_id, site_id, file_path)
                )
            """)

            # Document chunks table with vector embeddings
            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS {prefix}document_chunks (
                    chunk_id VARCHAR(128) PRIMARY KEY,
                    document_id VARCHAR(64) NOT NULL
                        REFERENCES {prefix}documents(document_id) ON DELETE CASCADE,
                    content TEXT NOT NULL,
                    chunk_index INTEGER NOT NULL,
                    embedding vector({self.embedding_dim}),
                    metadata JSONB,
                    UNIQUE(document_id, chunk_index)
                )
            """)

            # Tags table (many-to-many with documents)
            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS {prefix}document_tags (
                    document_id VARCHAR(64) NOT NULL
                        REFERENCES {prefix}documents(document_id) ON DELETE CASCADE,
                    tag VARCHAR(100) NOT NULL,
                    PRIMARY KEY(document_id, tag)
                )
            """)

            # Create indexes for performance
            cur.execute(f"""
                CREATE INDEX IF NOT EXISTS {index_prefix}idx_documents_user_id
                ON {prefix}documents(user_id)
            """)
            cur.execute(f"""
                CREATE INDEX IF NOT EXISTS {index_prefix}idx_document_tags_tag
                ON {prefix}document_tags(tag)
            """)

            # An IVFFlat index for vector similarity search is intentionally
            # not created here; create it manually after inserting ~1000+
            # vectors (use the prefixed table name):
            # CREATE INDEX ON document_chunks USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);

    def add_document(
        self,
        user_id: str,
        site_id: str,
        file_path: str,
        filename: str,
        content: str,
        tags: List[str],
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Add a document to the vector store with chunking and embeddings.

        If a document with the same (user_id, site_id, file_path) exists with
        identical content, nothing is re-embedded. Only its tags are replaced,
        and only when they differ. Otherwise the document row is upserted and
        all of its chunks are replaced in a single transaction.

        Args:
            user_id: User ID
            site_id: SharePoint site ID
            file_path: File path
            filename: Filename
            content: Document content
            tags: List of tags (e.g., ["HR", "SALES", "Q4-2024"]); stored lower-cased
            chunk_size: Size of each chunk in characters
            chunk_overlap: Overlap between chunks
            metadata: Additional metadata (stored on the document and on every chunk)

        Returns:
            document_id

        Raises:
            ValueError: If chunk_size is not positive or chunk_overlap is negative
            RuntimeError: If the provider returns a different number of
                embeddings than chunks
        """
        document_id = self._generate_document_id(user_id, site_id, file_path)
        content_hash = hashlib.sha256(content.encode()).hexdigest()
        normalized_tags = self._normalize_tags(tags)
        prefix = self.table_prefix

        # Phase 1 (short transaction): skip re-embedding unchanged content.
        with self.conn, self.conn.cursor() as cur:
            cur.execute(
                f"SELECT content_hash FROM {prefix}documents WHERE document_id = %s",
                (document_id,)
            )
            existing = cur.fetchone()

            if existing and existing[0] == content_hash:
                cur.execute(
                    f"SELECT tag FROM {prefix}document_tags WHERE document_id = %s",
                    (document_id,)
                )
                existing_tags = {row[0] for row in cur.fetchall()}
                # Compare normalized tags: stored tags are always lower-case.
                if set(normalized_tags) != existing_tags:
                    self._update_tags(cur, document_id, normalized_tags)
                    cur.execute(
                        f"UPDATE {prefix}documents SET updated_at = %s WHERE document_id = %s",
                        (_utcnow(), document_id)
                    )
                return document_id  # leaving the ``with`` block commits

        # Phase 2 (no transaction open): chunking and embedding may be slow or
        # fail on the network. Doing it first means a failure cannot leave a
        # half-written document behind.
        chunks = self._chunk_text(content, chunk_size, chunk_overlap)
        embeddings = self.embedding_provider.embed_batch(chunks)
        if len(embeddings) != len(chunks):
            raise RuntimeError(
                f"Embedding provider returned {len(embeddings)} embeddings "
                f"for {len(chunks)} chunks"
            )

        # Phase 3 (one transaction): replace the document's chunks atomically.
        now = _utcnow()
        with self.conn, self.conn.cursor() as cur:
            cur.execute(
                f"DELETE FROM {prefix}document_chunks WHERE document_id = %s",
                (document_id,)
            )

            # Upsert: created_at is preserved when the row already exists.
            cur.execute(f"""
                INSERT INTO {prefix}documents
                (document_id, user_id, site_id, file_path, filename, content_hash,
                 created_at, updated_at, chunk_count, metadata)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (document_id) DO UPDATE SET
                    content_hash = EXCLUDED.content_hash,
                    updated_at = EXCLUDED.updated_at,
                    chunk_count = EXCLUDED.chunk_count,
                    metadata = EXCLUDED.metadata
            """, (document_id, user_id, site_id, file_path, filename, content_hash,
                  now, now, len(chunks), Json(metadata)))

            chunk_rows = [
                (
                    f"{document_id}_chunk_{idx}",
                    document_id,
                    piece,
                    idx,
                    embedding,
                    Json(metadata),
                )
                for idx, (piece, embedding) in enumerate(zip(chunks, embeddings))
            ]
            if chunk_rows:
                execute_values(
                    cur,
                    f"""
                    INSERT INTO {prefix}document_chunks
                    (chunk_id, document_id, content, chunk_index, embedding, metadata)
                    VALUES %s
                    """,
                    chunk_rows
                )

            self._update_tags(cur, document_id, normalized_tags)

        return document_id

    def search(
        self,
        user_id: str,
        query: str,
        tags: Optional[List[str]] = None,
        top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Search for relevant document chunks using vector similarity.

        Args:
            user_id: User ID (for access control)
            query: Search query
            tags: Optional list of tags; a document matches if it has ANY of
                them (case-insensitive). Empty/None searches all documents.
            top_k: Number of results to return

        Returns:
            List of dicts with "chunk" (DocumentChunk), "document" (Document,
            carrying ALL of that document's tags) and "similarity" (1 minus
            cosine distance), most similar first.
        """
        query_embedding = self.embedding_provider.embed_text(query)
        prefix = self.table_prefix

        params: Dict[str, Any] = {
            "embedding": query_embedding,
            "user_id": user_id,
            "top_k": top_k,
        }
        tag_filter = ""
        if tags:
            # EXISTS filters without multiplying rows, so no GROUP BY is needed
            # and the returned tag list is not narrowed to the matching tags.
            tag_filter = f"""
              AND EXISTS (
                  SELECT 1 FROM {prefix}document_tags dt
                  WHERE dt.document_id = d.document_id
                    AND LOWER(dt.tag) = ANY(%(tags)s)
              )"""
            params["tags"] = [tag.lower() for tag in tags]

        # Ordering by the raw distance expression (ascending) is the pattern a
        # pgvector index can serve; similarity DESC is the same order.
        sql = f"""
            SELECT
                c.chunk_id, c.document_id, c.content, c.chunk_index,
                c.metadata AS chunk_metadata,
                d.user_id, d.site_id, d.file_path, d.filename,
                d.created_at, d.updated_at, d.chunk_count, d.content_hash,
                d.metadata AS doc_metadata,
                ARRAY(
                    SELECT t.tag FROM {prefix}document_tags t
                    WHERE t.document_id = d.document_id
                    ORDER BY t.tag
                ) AS tags,
                1 - (c.embedding <=> %(embedding)s::vector) AS similarity
            FROM {prefix}document_chunks c
            JOIN {prefix}documents d ON c.document_id = d.document_id
            WHERE d.user_id = %(user_id)s
              AND c.embedding IS NOT NULL{tag_filter}
            ORDER BY c.embedding <=> %(embedding)s::vector
            LIMIT %(top_k)s
        """

        with self.conn, self.conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(sql, params)
            rows = cur.fetchall()

        return [
            {
                "chunk": DocumentChunk(
                    chunk_id=row['chunk_id'],
                    document_id=row['document_id'],
                    content=row['content'],
                    chunk_index=row['chunk_index'],
                    metadata=row['chunk_metadata'],
                ),
                "document": self._row_to_document(row, metadata_key='doc_metadata'),
                "similarity": float(row['similarity']),
            }
            for row in rows
        ]

    def get_documents_by_tags(self, user_id: str, tags: List[str]) -> List[Document]:
        """Get the user's documents having ANY of ``tags`` (case-insensitive).

        Each returned Document lists all of its tags. An empty ``tags`` list
        returns [].
        """
        if not tags:
            return []

        prefix = self.table_prefix
        with self.conn, self.conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(f"""
                SELECT
                    d.document_id, d.user_id, d.site_id, d.file_path, d.filename,
                    d.content_hash, d.created_at, d.updated_at, d.chunk_count,
                    d.metadata,
                    ARRAY(
                        SELECT t.tag FROM {prefix}document_tags t
                        WHERE t.document_id = d.document_id
                        ORDER BY t.tag
                    ) AS tags
                FROM {prefix}documents d
                WHERE d.user_id = %s
                  AND EXISTS (
                      SELECT 1 FROM {prefix}document_tags dt
                      WHERE dt.document_id = d.document_id
                        AND LOWER(dt.tag) = ANY(%s)
                  )
            """, (user_id, [tag.lower() for tag in tags]))
            rows = cur.fetchall()

        return [self._row_to_document(row) for row in rows]

    def list_tags(self, user_id: str) -> Dict[str, int]:
        """List all tags for user with document counts (most used first)."""
        with self.conn, self.conn.cursor() as cur:
            cur.execute(f"""
                SELECT dt.tag, COUNT(DISTINCT dt.document_id) as count
                FROM {self.table_prefix}document_tags dt
                JOIN {self.table_prefix}documents d ON dt.document_id = d.document_id
                WHERE d.user_id = %s
                GROUP BY dt.tag
                ORDER BY count DESC, dt.tag
            """, (user_id,))
            return {row[0]: row[1] for row in cur.fetchall()}

    def get_indexed_sites(self, user_id: str) -> List[str]:
        """
        Get list of site IDs that have indexed documents for this user.

        Args:
            user_id: User ID

        Returns:
            List of site IDs with indexed documents, sorted ascending
        """
        with self.conn, self.conn.cursor() as cur:
            cur.execute(f"""
                SELECT DISTINCT site_id
                FROM {self.table_prefix}documents
                WHERE user_id = %s
                ORDER BY site_id
            """, (user_id,))
            return [row[0] for row in cur.fetchall()]

    def update_document_tags(self, document_id: str, user_id: str, tags: List[str]):
        """Replace the tags of a document owned by ``user_id``.

        Raises:
            ValueError: "Document not found" or "Access denied"; the
                transaction is rolled back in both cases.
        """
        with self.conn, self.conn.cursor() as cur:
            # Verify ownership
            cur.execute(
                f"SELECT user_id FROM {self.table_prefix}documents WHERE document_id = %s",
                (document_id,)
            )
            row = cur.fetchone()
            if not row:
                raise ValueError("Document not found")
            if row[0] != user_id:
                raise ValueError("Access denied")

            self._update_tags(cur, document_id, tags)
            cur.execute(
                f"UPDATE {self.table_prefix}documents SET updated_at = %s WHERE document_id = %s",
                (_utcnow(), document_id)
            )

    def remove_document(self, document_id: str, user_id: str):
        """Remove a document and its chunks; a missing document is a no-op.

        Raises:
            ValueError: "Access denied" if the document belongs to another user
        """
        with self.conn, self.conn.cursor() as cur:
            # Verify ownership
            cur.execute(
                f"SELECT user_id FROM {self.table_prefix}documents WHERE document_id = %s",
                (document_id,)
            )
            row = cur.fetchone()
            if not row:
                return
            if row[0] != user_id:
                raise ValueError("Access denied")

            # Delete document (cascades to chunks and tags)
            cur.execute(
                f"DELETE FROM {self.table_prefix}documents WHERE document_id = %s",
                (document_id,)
            )

    @staticmethod
    def _normalize_tags(tags: Optional[List[str]]) -> List[str]:
        """Lower-case tags and drop duplicates, keeping first-seen order."""
        return list(dict.fromkeys(tag.lower() for tag in (tags or [])))

    @staticmethod
    def _row_to_document(row: Dict[str, Any], metadata_key: str = 'metadata') -> Document:
        """Build a Document from a RealDictCursor row.

        ``metadata_key`` names the column holding the document's metadata.
        """
        return Document(
            document_id=row['document_id'],
            user_id=row['user_id'],
            site_id=row['site_id'],
            file_path=row['file_path'],
            filename=row['filename'],
            tags=list(row['tags'] or []),
            content_hash=row['content_hash'],
            created_at=row['created_at'].isoformat(),
            updated_at=row['updated_at'].isoformat(),
            chunk_count=row['chunk_count'],
            metadata=row[metadata_key],
        )

    def _update_tags(self, cur, document_id: str, tags: List[str]):
        """Replace a document's tags using the caller's cursor (and transaction).

        Tags are lower-cased and de-duplicated first. Inserting e.g.
        ["HR", "hr"] verbatim would violate the (document_id, tag) primary key.
        """
        normalized = self._normalize_tags(tags)
        cur.execute(
            f"DELETE FROM {self.table_prefix}document_tags WHERE document_id = %s",
            (document_id,)
        )
        if normalized:
            execute_values(
                cur,
                f"INSERT INTO {self.table_prefix}document_tags (document_id, tag) VALUES %s",
                [(document_id, tag) for tag in normalized]
            )

    def _chunk_text(self, text: str, chunk_size: int, overlap: int) -> List[str]:
        """
        Split text into overlapping chunks, preferring sentence/paragraph breaks.

        Args:
            text: Text to split
            chunk_size: Maximum characters per chunk (before whitespace stripping)
            overlap: Characters shared between consecutive chunks

        Returns:
            Non-empty, whitespace-stripped chunks in document order

        Raises:
            ValueError: If chunk_size is not positive or overlap is negative
        """
        if not text:
            return []
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        if overlap < 0:
            raise ValueError("chunk_overlap must not be negative")

        chunks = []
        start = 0
        text_length = len(text)

        while start < text_length:
            end = start + chunk_size

            if end < text_length:
                # Prefer a natural break, but only one far enough into the
                # window that the next chunk (starting ``overlap`` characters
                # before ``end``) still moves forward. Accepting an earlier
                # break made the loop repeat the same chunk forever or step
                # backwards.
                for punct in ('. ', '! ', '? ', '\n\n'):
                    last_punct = text.rfind(punct, start, end)
                    if last_punct != -1 and last_punct + len(punct) - overlap > start:
                        end = last_punct + len(punct)
                        break

            chunk = text[start:end].strip()
            if chunk:
                chunks.append(chunk)

            if end >= text_length:
                break
            next_start = end - overlap
            # If overlap >= chunk_size a hard cut cannot overlap and still
            # advance, so continue directly after this chunk instead.
            start = next_start if next_start > start else end

        return chunks

    def _generate_document_id(self, user_id: str, site_id: str, file_path: str) -> str:
        """Generate a deterministic document ID (SHA-256 hex of the key parts).

        NOTE(review): the ":" join is ambiguous if user_id/site_id contain ":".
        It is kept as-is because changing it would change the IDs of
        already-stored documents.
        """
        combined = f"{user_id}:{site_id}:{file_path}"
        return hashlib.sha256(combined.encode()).hexdigest()

    def close(self):
        """Close the database connection (safe to call more than once)."""
        conn = getattr(self, "conn", None)
        if conn is not None and not conn.closed:
            conn.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


def create_embedding_provider(provider: str = "ollama", **kwargs) -> EmbeddingProvider:
    """
    Factory function to create embedding provider.

    Args:
        provider: Provider name ("ollama" or "openai", case-insensitive)
        **kwargs: Provider-specific configuration:
            ollama: base_url (env OLLAMA_URL), model (env OLLAMA_EMBED_MODEL)
            openai: api_key (env OPENAI_API_KEY), model (env OPENAI_EMBED_MODEL)
            both:   timeout (seconds, None = no timeout)

    Returns:
        EmbeddingProvider instance

    Raises:
        ValueError: Unknown provider, or missing OpenAI API key
    """
    name = provider.lower()
    timeout = kwargs.get("timeout", DEFAULT_HTTP_TIMEOUT)

    if name == "ollama":
        return OllamaEmbeddings(
            base_url=kwargs.get("base_url", os.getenv("OLLAMA_URL", "http://localhost:11434")),
            model=kwargs.get("model", os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text")),
            timeout=timeout,
        )
    if name == "openai":
        api_key = kwargs.get("api_key", os.getenv("OPENAI_API_KEY"))
        if not api_key:
            raise ValueError("OpenAI API key required")
        return OpenAIEmbeddings(
            api_key=api_key,
            model=kwargs.get("model", os.getenv("OPENAI_EMBED_MODEL", "text-embedding-3-small")),
            timeout=timeout,
        )
    raise ValueError(f"Unknown embedding provider: {provider}")