Daniel Grozdanovic bcd0f8a227
Some checks failed
CI - SharePoint Plugin with SonarQube / Test and SonarQube Analysis (push) Has been cancelled
Initial commit: SharePoint connector and ToothFairyAI integration
2026-02-22 17:58:45 +02:00

478 lines
16 KiB
Python

"""
Vector Store for Multi-Document Chat with RAG (Retrieval-Augmented Generation)
Supports document embeddings, tagging, and semantic search for chatting with
multiple documents at once.
"""
import hashlib
import json
import math
import os
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
@dataclass
class DocumentChunk:
    """One contiguous piece of a document, ready for similarity search."""
    chunk_id: str                                  # unique id: "<document_id>_chunk_<index>"
    document_id: str                               # id of the parent Document
    content: str                                   # raw chunk text
    chunk_index: int                               # position of this chunk within the document
    embedding: Optional[List[float]] = None        # vector for the chunk; None until embedded
    metadata: Optional[Dict[str, Any]] = None      # caller-supplied extra info, if any
@dataclass
class Document:
    """Per-document bookkeeping record: ownership, tags, and change tracking."""
    document_id: str                               # stable hash of (user_id, site_id, file_path)
    user_id: str                                   # owner; used for access control
    site_id: str                                   # SharePoint site the file came from
    file_path: str                                 # path of the file within the site
    filename: str                                  # display name of the file
    tags: List[str]                                # user-assigned labels (e.g. ["HR", "SALES"])
    content_hash: str                              # SHA-256 of the content, for change detection
    created_at: str                                # ISO-8601 timestamp of first indexing
    updated_at: str                                # ISO-8601 timestamp of last change
    chunk_count: int                               # number of chunks stored for this document
    metadata: Optional[Dict[str, Any]] = None      # caller-supplied extra info, if any
class EmbeddingProvider:
    """Abstract interface for turning text into embedding vectors."""

    def embed_text(self, text: str) -> List[float]:
        """Return the embedding vector for a single text.

        Subclasses must override this.
        """
        raise NotImplementedError

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed many texts by delegating to ``embed_text`` one at a time."""
        return list(map(self.embed_text, texts))
class OllamaEmbeddings(EmbeddingProvider):
    """Embeddings backed by a locally running Ollama server."""

    def __init__(self, base_url: str = "http://localhost:11434", model: str = "nomic-embed-text"):
        """
        Configure the client.

        Args:
            base_url: Ollama server URL
            model: Embedding model (e.g., nomic-embed-text, mxbai-embed-large)
        """
        # Imported lazily so the module loads even where requests is absent.
        import requests

        self.base_url = base_url.rstrip('/')
        self.model = model
        self.requests = requests

    def embed_text(self, text: str) -> List[float]:
        """POST the text to /api/embeddings and return the resulting vector."""
        payload = {"model": self.model, "prompt": text}
        resp = self.requests.post(f"{self.base_url}/api/embeddings", json=payload)
        resp.raise_for_status()
        return resp.json()["embedding"]
class OpenAIEmbeddings(EmbeddingProvider):
    """Embeddings generated via the OpenAI embeddings REST API."""

    def __init__(self, api_key: str, model: str = "text-embedding-3-small"):
        """
        Configure the client.

        Args:
            api_key: OpenAI API key
            model: Embedding model (text-embedding-3-small or text-embedding-3-large)
        """
        # Imported lazily so the module loads even where requests is absent.
        import requests

        self.api_key = api_key
        self.model = model
        self.requests = requests

    def embed_text(self, text: str) -> List[float]:
        """Send the text to the OpenAI API and return its embedding vector."""
        auth_headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        resp = self.requests.post(
            "https://api.openai.com/v1/embeddings",
            headers=auth_headers,
            json={"input": text, "model": self.model},
        )
        resp.raise_for_status()
        return resp.json()["data"][0]["embedding"]
class InMemoryVectorStore:
    """
    Simple in-memory vector store for development.
    For production, use a proper vector database like:
    - Pinecone
    - Weaviate
    - Qdrant
    - ChromaDB
    - pgvector (PostgreSQL extension)
    """

    def __init__(self, embedding_provider: EmbeddingProvider):
        """Initialize the store with the provider used to embed text."""
        self.embedding_provider = embedding_provider
        # document_id -> Document metadata
        self.documents: Dict[str, Document] = {}
        # chunk_id -> DocumentChunk (text + embedding)
        self.chunks: Dict[str, DocumentChunk] = {}
        self.user_documents: Dict[str, List[str]] = {}  # user_id -> [document_ids]
        self.tag_index: Dict[str, List[str]] = {}  # lowercase tag -> [document_ids]

    def add_document(
        self,
        user_id: str,
        site_id: str,
        file_path: str,
        filename: str,
        content: str,
        tags: List[str],
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Add a document to the vector store with chunking and embeddings.

        Re-adding an unchanged document only refreshes its tags; re-adding a
        changed document re-chunks and re-embeds it, cleans up its old tag
        index entries, and preserves its original created_at timestamp.

        Args:
            user_id: User ID
            site_id: SharePoint site ID
            file_path: File path
            filename: Filename
            content: Document content
            tags: List of tags (e.g., ["HR", "SALES", "Q4-2024"])
            chunk_size: Size of each chunk in characters
            chunk_overlap: Overlap between chunks
            metadata: Additional metadata

        Returns:
            document_id
        """
        document_id = self._generate_document_id(user_id, site_id, file_path)
        content_hash = hashlib.sha256(content.encode()).hexdigest()

        created_at = None  # preserved across re-indexing of a changed document
        if document_id in self.documents:
            existing_doc = self.documents[document_id]
            if existing_doc.content_hash == content_hash:
                # Unchanged content: only sync tags and bump updated_at.
                if set(tags) != set(existing_doc.tags):
                    self._update_tags(document_id, existing_doc.tags, tags)
                    existing_doc.tags = tags
                existing_doc.updated_at = datetime.now(timezone.utc).isoformat()
                return document_id
            # Content changed: drop stale chunks AND stale tag-index entries
            # (previously, old tags kept pointing at this document).
            self._remove_document_chunks(document_id)
            self._update_tags(document_id, existing_doc.tags, tags)
            created_at = existing_doc.created_at

        # Chunk the document and embed every chunk.
        chunk_texts = self._chunk_text(content, chunk_size, chunk_overlap)
        embeddings = self.embedding_provider.embed_batch(chunk_texts)

        document_chunks = []
        for idx, (chunk_text, embedding) in enumerate(zip(chunk_texts, embeddings)):
            chunk_id = f"{document_id}_chunk_{idx}"
            chunk = DocumentChunk(
                chunk_id=chunk_id,
                document_id=document_id,
                content=chunk_text,
                chunk_index=idx,
                embedding=embedding,
                metadata=metadata
            )
            self.chunks[chunk_id] = chunk
            document_chunks.append(chunk)

        now = datetime.now(timezone.utc).isoformat()
        document = Document(
            document_id=document_id,
            user_id=user_id,
            site_id=site_id,
            file_path=file_path,
            filename=filename,
            tags=tags,
            content_hash=content_hash,
            created_at=created_at or now,
            updated_at=now,
            chunk_count=len(document_chunks),
            metadata=metadata
        )
        self.documents[document_id] = document

        # Index by owner.
        owned = self.user_documents.setdefault(user_id, [])
        if document_id not in owned:
            owned.append(document_id)

        # Index by tag (case-insensitive).
        for tag in tags:
            bucket = self.tag_index.setdefault(tag.lower(), [])
            if document_id not in bucket:
                bucket.append(document_id)
        return document_id

    def search(
        self,
        user_id: str,
        query: str,
        tags: Optional[List[str]] = None,
        top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Search for relevant document chunks.

        Args:
            user_id: User ID (for access control)
            query: Search query
            tags: Optional list of tags to filter by
            top_k: Number of results to return

        Returns:
            List of dicts with keys "chunk", "document", "similarity",
            sorted by descending cosine similarity.
        """
        query_embedding = self.embedding_provider.embed_text(query)

        # Candidate documents: the user's documents, optionally narrowed by tag.
        user_doc_ids = set(self.user_documents.get(user_id, []))
        if tags:
            tagged_ids = set()
            for tag in tags:
                tagged_ids.update(self.tag_index.get(tag.lower(), []))
            candidate_doc_ids = tagged_ids & user_doc_ids
        else:
            candidate_doc_ids = user_doc_ids

        # Score every embedded chunk of every candidate document.
        results = []
        for chunk in self.chunks.values():
            if chunk.document_id not in candidate_doc_ids:
                continue
            if not chunk.embedding:
                continue
            results.append({
                "chunk": chunk,
                "document": self.documents[chunk.document_id],
                "similarity": self._cosine_similarity(query_embedding, chunk.embedding)
            })

        results.sort(key=lambda x: x["similarity"], reverse=True)
        return results[:top_k]

    def get_documents_by_tags(self, user_id: str, tags: List[str]) -> List[Document]:
        """Get all of the user's documents carrying any of the given tags."""
        doc_ids = set()
        for tag in tags:
            doc_ids.update(self.tag_index.get(tag.lower(), []))
        # Access control: only documents owned by this user.
        doc_ids &= set(self.user_documents.get(user_id, []))
        return [self.documents[doc_id] for doc_id in doc_ids if doc_id in self.documents]

    def list_tags(self, user_id: str) -> Dict[str, int]:
        """List all tags on the user's documents with per-tag document counts."""
        user_doc_ids = set(self.user_documents.get(user_id, []))
        tag_counts = {}
        for tag, doc_ids in self.tag_index.items():
            count = len(set(doc_ids) & user_doc_ids)
            if count:
                tag_counts[tag] = count
        return tag_counts

    def update_document_tags(self, document_id: str, user_id: str, tags: List[str]):
        """Replace a document's tags, keeping the tag index in sync.

        Raises:
            ValueError: If the document does not exist or belongs to another user.
        """
        if document_id not in self.documents:
            raise ValueError("Document not found")
        doc = self.documents[document_id]
        if doc.user_id != user_id:
            raise ValueError("Access denied")
        self._update_tags(document_id, doc.tags, tags)
        doc.tags = tags
        doc.updated_at = datetime.now(timezone.utc).isoformat()

    def remove_document(self, document_id: str, user_id: str):
        """Remove a document, its chunks, and every index entry pointing at it.

        A missing document is a no-op; removing another user's document
        raises ValueError.
        """
        if document_id not in self.documents:
            return
        doc = self.documents[document_id]
        if doc.user_id != user_id:
            raise ValueError("Access denied")
        # Passing an empty new-tags list strips the document from the tag index.
        self._update_tags(document_id, doc.tags, [])
        self._remove_document_chunks(document_id)
        if user_id in self.user_documents:
            self.user_documents[user_id] = [
                d for d in self.user_documents[user_id] if d != document_id
            ]
        del self.documents[document_id]

    def _chunk_text(self, text: str, chunk_size: int, overlap: int) -> List[str]:
        """Split text into overlapping chunks, preferring sentence boundaries.

        Boundary preference order is '. ', '! ', '? ', then paragraph break.
        """
        if not text:
            return []
        chunks = []
        start = 0
        text_length = len(text)
        while start < text_length:
            end = start + chunk_size
            if end < text_length:
                # Shrink the window back to the last sentence boundary, if any.
                for punct in ('. ', '! ', '? ', '\n\n'):
                    last_punct = text.rfind(punct, start, end)
                    if last_punct != -1:
                        end = last_punct + len(punct)
                        break
            piece = text[start:end].strip()
            if piece:
                chunks.append(piece)
            if end >= text_length:
                break
            # Always advance by at least one character: a boundary found within
            # `overlap` characters of `start` would otherwise make
            # `end - overlap` <= start (or even negative) and loop forever.
            start = max(end - overlap, start + 1)
        return chunks

    def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
        """Cosine similarity of two vectors; 0.0 if either has zero magnitude."""
        dot_product = sum(a * b for a, b in zip(vec1, vec2))
        magnitude1 = math.sqrt(sum(a * a for a in vec1))
        magnitude2 = math.sqrt(sum(b * b for b in vec2))
        if magnitude1 == 0 or magnitude2 == 0:
            return 0.0
        return dot_product / (magnitude1 * magnitude2)

    def _generate_document_id(self, user_id: str, site_id: str, file_path: str) -> str:
        """Derive a stable unique document ID from owner, site, and path."""
        combined = f"{user_id}:{site_id}:{file_path}"
        return hashlib.sha256(combined.encode()).hexdigest()

    def _remove_document_chunks(self, document_id: str):
        """Delete every stored chunk belonging to the given document."""
        stale = [
            chunk_id for chunk_id, chunk in self.chunks.items()
            if chunk.document_id == document_id
        ]
        for chunk_id in stale:
            del self.chunks[chunk_id]

    def _update_tags(self, document_id: str, old_tags: List[str], new_tags: List[str]):
        """Move a document between tag-index buckets when its tags change."""
        for tag in old_tags:
            tag_lower = tag.lower()
            if tag_lower in self.tag_index:
                self.tag_index[tag_lower] = [
                    d for d in self.tag_index[tag_lower] if d != document_id
                ]
        for tag in new_tags:
            bucket = self.tag_index.setdefault(tag.lower(), [])
            if document_id not in bucket:
                bucket.append(document_id)
def create_embedding_provider(provider: str = "ollama", **kwargs) -> EmbeddingProvider:
"""
Factory function to create embedding provider.
Args:
provider: Provider name ("ollama" or "openai")
**kwargs: Provider-specific configuration
Returns:
EmbeddingProvider instance
Example:
# Ollama (default)
embeddings = create_embedding_provider("ollama", model="nomic-embed-text")
# OpenAI
embeddings = create_embedding_provider("openai", api_key="sk-...")
"""
if provider.lower() == "ollama":
return OllamaEmbeddings(
base_url=kwargs.get("base_url", os.getenv("OLLAMA_URL", "http://localhost:11434")),
model=kwargs.get("model", os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text"))
)
elif provider.lower() == "openai":
api_key = kwargs.get("api_key", os.getenv("OPENAI_API_KEY"))
if not api_key:
raise ValueError("OpenAI API key required")
return OpenAIEmbeddings(
api_key=api_key,
model=kwargs.get("model", "text-embedding-3-small")
)
else:
raise ValueError(f"Unknown embedding provider: {provider}")