"""
Vector store for multi-document chat with RAG (Retrieval-Augmented Generation).

Supports document embeddings, tagging, and semantic search for chatting with
multiple documents at once.
"""

import hashlib
import json
import math
import os
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Any, Dict, List, Optional


@dataclass
class DocumentChunk:
    """A chunk of a document with metadata.

    Chunks are produced by splitting a document's text into overlapping
    windows; each chunk carries its own embedding for similarity search.
    """

    chunk_id: str  # Unique ID, formatted as "<document_id>_chunk_<index>".
    document_id: str  # ID of the parent document.
    content: str  # Raw text of this chunk.
    chunk_index: int  # Zero-based position of the chunk within the document.
    embedding: Optional[List[float]] = None  # Dense vector; None until embedded.
    metadata: Optional[Dict[str, Any]] = None  # Caller-supplied extra data.
@dataclass
class Document:
    """Document metadata with tags."""

    document_id: str  # SHA-256 of "user_id:site_id:file_path".
    user_id: str  # Owner; used for access control on search/update/remove.
    site_id: str  # SharePoint site the file came from.
    file_path: str  # Path of the file within the site.
    filename: str  # Display filename.
    tags: List[str]  # Free-form labels (e.g. "HR", "SALES"); indexed case-insensitively.
    content_hash: str  # SHA-256 of the document content, used to detect changes.
    created_at: str  # ISO-8601 timestamp string.
    updated_at: str  # ISO-8601 timestamp string.
    chunk_count: int  # Number of chunks stored for this document.
    metadata: Optional[Dict[str, Any]] = None  # Caller-supplied extra data.
class EmbeddingProvider:
    """Base class for embedding providers."""

    def embed_text(self, text: str) -> List[float]:
        """Generate an embedding vector for a single text.

        Subclasses must override this method.
        """
        raise NotImplementedError

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple texts.

        Default implementation embeds each text individually via
        ``embed_text``; providers with a native batch API may override.
        """
        return list(map(self.embed_text, texts))
class OllamaEmbeddings(EmbeddingProvider):
    """Ollama embeddings using local models."""

    def __init__(self, base_url: str = "http://localhost:11434", model: str = "nomic-embed-text", timeout: float = 60.0):
        """
        Initialize Ollama embeddings.

        Args:
            base_url: Ollama server URL
            model: Embedding model (e.g., nomic-embed-text, mxbai-embed-large)
            timeout: Request timeout in seconds for each embedding call
        """
        import requests

        self.base_url = base_url.rstrip('/')
        self.model = model
        self.timeout = timeout
        self.requests = requests

    def embed_text(self, text: str) -> List[float]:
        """Generate embedding using Ollama.

        Raises:
            requests.HTTPError: If the server returns an error status.
        """
        # Without an explicit timeout, requests waits indefinitely on a
        # stalled connection; bound the call so a hung server can't block us.
        response = self.requests.post(
            f"{self.base_url}/api/embeddings",
            json={
                "model": self.model,
                "prompt": text
            },
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()["embedding"]
class OpenAIEmbeddings(EmbeddingProvider):
    """OpenAI embeddings."""

    def __init__(self, api_key: str, model: str = "text-embedding-3-small", timeout: float = 60.0):
        """
        Initialize OpenAI embeddings.

        Args:
            api_key: OpenAI API key
            model: Embedding model (text-embedding-3-small or text-embedding-3-large)
            timeout: Request timeout in seconds for each embedding call
        """
        import requests

        self.api_key = api_key
        self.model = model
        self.timeout = timeout
        self.requests = requests

    def embed_text(self, text: str) -> List[float]:
        """Generate embedding using OpenAI.

        Raises:
            requests.HTTPError: If the API returns an error status.
        """
        # Without an explicit timeout, requests waits indefinitely on a
        # stalled connection; bound the call so callers can't hang.
        response = self.requests.post(
            "https://api.openai.com/v1/embeddings",
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            },
            json={
                "input": text,
                "model": self.model
            },
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()["data"][0]["embedding"]
class InMemoryVectorStore:
    """
    Simple in-memory vector store for development.

    For production, use a proper vector database like:
    - Pinecone
    - Weaviate
    - Qdrant
    - ChromaDB
    - pgvector (PostgreSQL extension)
    """

    def __init__(self, embedding_provider: EmbeddingProvider):
        """Initialize vector store.

        Args:
            embedding_provider: Provider used to embed documents and queries.
        """
        self.embedding_provider = embedding_provider
        self.documents: Dict[str, Document] = {}
        self.chunks: Dict[str, DocumentChunk] = {}
        self.user_documents: Dict[str, List[str]] = {}  # user_id -> [document_ids]
        self.tag_index: Dict[str, List[str]] = {}  # lowercased tag -> [document_ids]

    def add_document(
        self,
        user_id: str,
        site_id: str,
        file_path: str,
        filename: str,
        content: str,
        tags: List[str],
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Add a document to the vector store with chunking and embeddings.

        Re-adding a document with unchanged content is a no-op except for a
        possible tag update; changed content replaces all existing chunks.

        Args:
            user_id: User ID
            site_id: SharePoint site ID
            file_path: File path
            filename: Filename
            content: Document content
            tags: List of tags (e.g., ["HR", "SALES", "Q4-2024"])
            chunk_size: Size of each chunk in characters
            chunk_overlap: Overlap between chunks
            metadata: Additional metadata (stored on the document and each chunk)

        Returns:
            document_id
        """
        document_id = self._generate_document_id(user_id, site_id, file_path)

        # Hash the content so we can cheaply detect "nothing changed".
        content_hash = hashlib.sha256(content.encode()).hexdigest()

        if document_id in self.documents:
            existing_doc = self.documents[document_id]
            if existing_doc.content_hash == content_hash:
                # Unchanged content: only refresh tags if they differ.
                if set(tags) != set(existing_doc.tags):
                    self._update_tags(document_id, existing_doc.tags, tags)
                    existing_doc.tags = list(tags)  # copy: don't alias caller's list
                    existing_doc.updated_at = datetime.utcnow().isoformat()
                return document_id
            # Content changed: drop stale chunks before re-indexing.
            self._remove_document_chunks(document_id)

        # Chunk the document and embed every chunk.
        chunk_texts = self._chunk_text(content, chunk_size, chunk_overlap)
        embeddings = self.embedding_provider.embed_batch(chunk_texts)

        document_chunks = []
        for idx, (chunk_text, embedding) in enumerate(zip(chunk_texts, embeddings)):
            chunk = DocumentChunk(
                chunk_id=f"{document_id}_chunk_{idx}",
                document_id=document_id,
                content=chunk_text,
                chunk_index=idx,
                embedding=embedding,
                metadata=metadata
            )
            self.chunks[chunk.chunk_id] = chunk
            document_chunks.append(chunk)

        # NOTE(review): naive UTC timestamps; datetime.utcnow() is deprecated
        # in Python 3.12 — consider datetime.now(timezone.utc) if the string
        # format change (trailing "+00:00") is acceptable to consumers.
        now = datetime.utcnow().isoformat()
        document = Document(
            document_id=document_id,
            user_id=user_id,
            site_id=site_id,
            file_path=file_path,
            filename=filename,
            tags=list(tags),  # copy: don't alias caller's list
            content_hash=content_hash,
            created_at=now,
            updated_at=now,
            chunk_count=len(document_chunks),
            metadata=metadata
        )
        self.documents[document_id] = document

        # Update the per-user index.
        user_docs = self.user_documents.setdefault(user_id, [])
        if document_id not in user_docs:
            user_docs.append(document_id)

        # Update the (case-insensitive) tag index.
        for tag in tags:
            bucket = self.tag_index.setdefault(tag.lower(), [])
            if document_id not in bucket:
                bucket.append(document_id)

        return document_id

    def search(
        self,
        user_id: str,
        query: str,
        tags: Optional[List[str]] = None,
        top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Search for relevant document chunks.

        Args:
            user_id: User ID (for access control)
            query: Search query
            tags: Optional list of tags to filter by (a document matches if it
                carries ANY of the tags)
            top_k: Number of results to return

        Returns:
            List of dicts with "chunk", "document" and "similarity" keys,
            sorted by descending cosine similarity.
        """
        query_embedding = self.embedding_provider.embed_text(query)

        # Restrict candidates to the user's own documents, optionally narrowed
        # to those carrying at least one of the requested tags.
        user_doc_ids = set(self.user_documents.get(user_id, []))
        if tags:
            candidate_doc_ids = set()
            for tag in tags:
                candidate_doc_ids.update(self.tag_index.get(tag.lower(), []))
            candidate_doc_ids &= user_doc_ids
        else:
            candidate_doc_ids = user_doc_ids

        results = []
        for chunk in self.chunks.values():
            if chunk.document_id not in candidate_doc_ids:
                continue
            if not chunk.embedding:
                continue  # skip chunks that were never embedded
            similarity = self._cosine_similarity(query_embedding, chunk.embedding)
            results.append({
                "chunk": chunk,
                "document": self.documents[chunk.document_id],
                "similarity": similarity
            })

        results.sort(key=lambda item: item["similarity"], reverse=True)
        return results[:top_k]

    def get_documents_by_tags(self, user_id: str, tags: List[str]) -> List[Document]:
        """Get all of the user's documents carrying ANY of the given tags."""
        doc_ids = set()
        for tag in tags:
            doc_ids.update(self.tag_index.get(tag.lower(), []))

        # Enforce per-user visibility.
        doc_ids &= set(self.user_documents.get(user_id, []))

        return [self.documents[doc_id] for doc_id in doc_ids if doc_id in self.documents]

    def list_tags(self, user_id: str) -> Dict[str, int]:
        """List all tags for user with document counts (tags are lowercased)."""
        user_doc_ids = set(self.user_documents.get(user_id, []))
        tag_counts: Dict[str, int] = {}

        for tag, doc_ids in self.tag_index.items():
            count = len(set(doc_ids) & user_doc_ids)
            if count:
                tag_counts[tag] = count

        return tag_counts

    def update_document_tags(self, document_id: str, user_id: str, tags: List[str]):
        """Update tags for a document.

        Raises:
            ValueError: If the document does not exist or belongs to another user.
        """
        if document_id not in self.documents:
            raise ValueError("Document not found")

        doc = self.documents[document_id]
        if doc.user_id != user_id:
            raise ValueError("Access denied")

        self._update_tags(document_id, doc.tags, tags)
        doc.tags = list(tags)  # copy: don't alias caller's list
        doc.updated_at = datetime.utcnow().isoformat()

    def remove_document(self, document_id: str, user_id: str):
        """Remove a document and its chunks. No-op if the document is unknown.

        Raises:
            ValueError: If the document belongs to another user.
        """
        if document_id not in self.documents:
            return

        doc = self.documents[document_id]
        if doc.user_id != user_id:
            raise ValueError("Access denied")

        # Drop the document from every tag bucket it appears in.
        for tag in doc.tags:
            tag_lower = tag.lower()
            if tag_lower in self.tag_index:
                self.tag_index[tag_lower] = [
                    d for d in self.tag_index[tag_lower] if d != document_id
                ]

        self._remove_document_chunks(document_id)

        # Drop the document from the owner's index.
        if user_id in self.user_documents:
            self.user_documents[user_id] = [
                d for d in self.user_documents[user_id] if d != document_id
            ]

        del self.documents[document_id]

    def _chunk_text(self, text: str, chunk_size: int, overlap: int) -> List[str]:
        """Split text into overlapping chunks, preferring sentence boundaries.

        Args:
            text: Text to split.
            chunk_size: Target chunk length in characters.
            overlap: Characters of overlap between consecutive chunks.

        Returns:
            List of non-empty, stripped chunk strings.
        """
        if not text:
            return []

        sentence_breaks = ('. ', '! ', '? ', '\n\n')
        chunks: List[str] = []
        start = 0
        text_length = len(text)

        while start < text_length:
            end = start + chunk_size

            if end < text_length:
                # End at the LATEST sentence boundary inside the window.
                # (The previous code took whichever delimiter matched first in
                # a fixed order, which could truncate chunks far too early.)
                best_pos = -1
                best_len = 0
                for punct in sentence_breaks:
                    pos = text.rfind(punct, start, end)
                    if pos > best_pos:
                        best_pos = pos
                        best_len = len(punct)
                if best_pos != -1:
                    end = best_pos + best_len

            piece = text[start:end].strip()
            if piece:
                chunks.append(piece)

            if end >= text_length:
                break
            # Step back by `overlap`, but always advance at least one char:
            # the previous code looped forever (or walked backwards) when a
            # boundary landed within `overlap` characters of `start`.
            start = max(end - overlap, start + 1)

        return chunks

    def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
        """Calculate cosine similarity between two vectors (0.0 for zero vectors)."""
        dot_product = sum(a * b for a, b in zip(vec1, vec2))
        magnitude1 = math.sqrt(sum(a * a for a in vec1))
        magnitude2 = math.sqrt(sum(b * b for b in vec2))

        if magnitude1 == 0 or magnitude2 == 0:
            return 0.0

        return dot_product / (magnitude1 * magnitude2)

    def _generate_document_id(self, user_id: str, site_id: str, file_path: str) -> str:
        """Derive a stable, unique document ID from owner, site, and path."""
        combined = f"{user_id}:{site_id}:{file_path}"
        return hashlib.sha256(combined.encode()).hexdigest()

    def _remove_document_chunks(self, document_id: str):
        """Remove all stored chunks belonging to a document."""
        stale = [
            chunk_id for chunk_id, chunk in self.chunks.items()
            if chunk.document_id == document_id
        ]
        for chunk_id in stale:
            del self.chunks[chunk_id]

    def _update_tags(self, document_id: str, old_tags: List[str], new_tags: List[str]):
        """Move a document between tag-index buckets when its tags change."""
        # Remove from old tag buckets.
        for tag in old_tags:
            tag_lower = tag.lower()
            if tag_lower in self.tag_index:
                self.tag_index[tag_lower] = [
                    d for d in self.tag_index[tag_lower] if d != document_id
                ]

        # Add to new tag buckets.
        for tag in new_tags:
            bucket = self.tag_index.setdefault(tag.lower(), [])
            if document_id not in bucket:
                bucket.append(document_id)
def create_embedding_provider(provider: str = "ollama", **kwargs) -> EmbeddingProvider:
    """
    Factory function to create embedding provider.

    Args:
        provider: Provider name ("ollama" or "openai"), case-insensitive
        **kwargs: Provider-specific configuration

    Returns:
        EmbeddingProvider instance

    Raises:
        ValueError: If the provider name is unknown, or if "openai" is
            requested without an API key.

    Example:
        # Ollama (default)
        embeddings = create_embedding_provider("ollama", model="nomic-embed-text")

        # OpenAI
        embeddings = create_embedding_provider("openai", api_key="sk-...")
    """
    name = provider.lower()

    if name == "ollama":
        url = kwargs.get("base_url", os.getenv("OLLAMA_URL", "http://localhost:11434"))
        model = kwargs.get("model", os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text"))
        return OllamaEmbeddings(base_url=url, model=model)

    if name == "openai":
        api_key = kwargs.get("api_key", os.getenv("OPENAI_API_KEY"))
        if not api_key:
            raise ValueError("OpenAI API key required")
        model = kwargs.get("model", "text-embedding-3-small")
        return OpenAIEmbeddings(api_key=api_key, model=model)

    raise ValueError(f"Unknown embedding provider: {provider}")