Initial commit: SharePoint connector and ToothFairyAI integration
Some checks failed
CI - SharePoint Plugin with SonarQube / Test and SonarQube Analysis (push) Has been cancelled

2026-02-22 17:58:45 +02:00
commit bcd0f8a227
29 changed files with 9410 additions and 0 deletions

33
.env.example Normal file

@@ -0,0 +1,33 @@
# SharePoint Connector Configuration
# Azure App Registration (REQUIRED)
SHAREPOINT_CLIENT_ID=your-client-id-here
SHAREPOINT_CLIENT_SECRET=your-client-secret-here
SHAREPOINT_TENANT_ID=common
# OAuth Callback (REQUIRED - must match Azure app registration)
# Local development:
REDIRECT_URI=http://localhost:5000/sharepoint/callback
# Production:
# REDIRECT_URI=https://yourdomain.com/sharepoint/callback
# Security (REQUIRED)
# Generate with: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
ENCRYPTION_KEY=your-generated-encryption-key-here
# Flask Security (OPTIONAL - auto-generated if not set)
# Generate with: python -c "import secrets; print(secrets.token_urlsafe(32))"
FLASK_SECRET_KEY=your-generated-flask-secret-here
# AWS Configuration (REQUIRED for DynamoDB)
AWS_REGION=us-east-1
# AWS_ACCESS_KEY_ID=your-key-id # Optional if using IAM role
# AWS_SECRET_ACCESS_KEY=your-secret-key # Optional if using IAM role
# DynamoDB Settings (OPTIONAL)
# Optional: prefix table names (e.g., "dev_" or "prod_")
TABLE_PREFIX=
# Uncomment for local DynamoDB
# DYNAMODB_ENDPOINT=http://localhost:8000
# Flask Settings (OPTIONAL)
# Set to "true" for development
FLASK_DEBUG=false
PORT=5000
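
The settings above fall into required and optional groups. A minimal sketch of how an application might validate them at startup is shown below. The `load_config` helper and the two constant names are hypothetical (they do not appear in this commit); the variable names and defaults are taken from the example file.

```python
import os
from typing import Dict, Mapping, Optional

# Variables the example file marks as REQUIRED.
REQUIRED_VARS = (
    "SHAREPOINT_CLIENT_ID",
    "SHAREPOINT_CLIENT_SECRET",
    "SHAREPOINT_TENANT_ID",
    "REDIRECT_URI",
    "ENCRYPTION_KEY",
    "AWS_REGION",
)

# Optional variables with the defaults shown in the example file.
OPTIONAL_DEFAULTS = {
    "TABLE_PREFIX": "",
    "FLASK_DEBUG": "false",
    "PORT": "5000",
}


def load_config(env: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Hypothetical helper: collect settings and fail fast on missing ones."""
    env = os.environ if env is None else env
    # Treat unset and empty values the same way for required settings.
    missing = [name for name in REQUIRED_VARS if not env.get(name)]
    if missing:
        raise RuntimeError("Missing required settings: " + ", ".join(missing))
    config = {name: env[name] for name in REQUIRED_VARS}
    for name, default in OPTIONAL_DEFAULTS.items():
        config[name] = env.get(name, default)
    return config
```

Failing at startup with the full list of missing names tends to be easier to debug than a later error inside the OAuth callback.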

View File

@@ -0,0 +1,96 @@
on:
  push:
    branches:
      - main
      - develop
      - feature/*
  pull_request:
    branches:
      - main
      - develop
      - feature/*
name: CI - SharePoint Plugin with SonarQube
jobs:
  test-and-scan:
    name: Test and SonarQube Analysis
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: plugins/sharepoint
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      - name: Cache pip dependencies
        uses: actions/cache@v3
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-pip-${{ hashFiles('plugins/sharepoint/requirements.txt') }}
          restore-keys: |
            ${{ runner.os }}-pip-
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install pytest pytest-cov pytest-mock
      - name: Run tests with coverage
        run: |
          pytest \
            --cov=. \
            --cov-report=xml:coverage.xml \
            --cov-report=term \
            --cov-report=html \
            -v \
            --ignore=venv \
            --ignore=.venv
        continue-on-error: false
      - name: Verify coverage file
        run: |
          pwd
          ls -la
          if [ -f coverage.xml ]; then
            echo "✅ Coverage file exists"
            echo "First 30 lines of coverage.xml:"
            head -30 coverage.xml
          else
            echo "❌ ERROR: Coverage file was not created!"
            exit 1
          fi
      - name: SonarQube Scan
        uses: SonarSource/sonarqube-scan-action@v6
        with:
          projectBaseDir: plugins/sharepoint
          # The project name contains spaces, so it is quoted to stay one argument.
          # Test files are excluded from the main sources so that no file is
          # indexed both as source and as test.
          args: >
            -Dsonar.projectKey=sharepoint-connector
            -Dsonar.projectName="SharePoint Connector Plugin"
            -Dsonar.sources=.
            -Dsonar.python.coverage.reportPaths=coverage.xml
            -Dsonar.tests=.
            -Dsonar.test.inclusions=test_*.py
            -Dsonar.exclusions=venv/**,**/__pycache__/**,*.pyc,.venv/**,htmlcov/**,templates/**,static/**,test_*.py
            -Dsonar.python.version=3.11
        env:
          SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }}
          SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }}
      - name: SonarQube Quality Gate Check
        uses: SonarSource/sonarqube-quality-gate-action@v1
        timeout-minutes: 5
        env:
          SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }}
          SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }}

44
.gitignore vendored Normal file

@@ -0,0 +1,44 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
# Virtual Environment
venv/
env/
ENV/
.venv
# IDE
.vscode/
.idea/
*.swp
*.swo
*~
# Environment variables
.env
.env.local
# Logs
*.log
# OS
.DS_Store
Thumbs.db
# Application data
/home/data/
# Pytest
.pytest_cache/
.coverage
coverage.xml
htmlcov/
# Distribution
dist/
build/
*.egg-info/

23
Dockerfile Normal file

@@ -0,0 +1,23 @@
FROM python:3.11-slim
WORKDIR /app
# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application files
COPY app.py .
COPY saas_connector_dynamodb.py .
COPY templates/ templates/
COPY static/ static/
# Expose port
EXPOSE 8000
# Health check (fails on connection errors and on non-2xx responses)
HEALTHCHECK --interval=30s --timeout=3s --start-period=40s --retries=3 \
    CMD python -c "import requests; requests.get('http://localhost:8000/health', timeout=2).raise_for_status()"
# Run with gunicorn (production-grade WSGI server)
CMD ["gunicorn", "--bind", "0.0.0.0:8000", "--workers", "4", "--timeout", "120", "app:app"]

67
README.md Normal file

@@ -0,0 +1,67 @@
# SharePoint Connector for SaaS Applications

Enterprise-grade SharePoint integration for SaaS applications. Features secure OAuth 2.0, multi-tenant isolation, automatic text extraction, and AI-powered document chat (RAG) using local or cloud LLMs.

## Quick Start - Local Development

### 1. Install Dependencies

```bash
python3 -m venv venv
source venv/bin/activate  # Windows: .\venv\Scripts\activate
pip install -r requirements.txt
```

### 2. Start Required Services (Docker)

You need local instances of DynamoDB (for configurations/sessions) and PostgreSQL with pgvector (for document embeddings).

```bash
# Start DynamoDB Local
docker run -p 8000:8000 amazon/dynamodb-local

# Start PostgreSQL with pgvector
docker run -d --name postgres-vector -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=sharepoint_vectors -p 5432:5432 pgvector/pgvector:pg16
```

### 3. Start Ollama (For AI & Embeddings)

Ensure Ollama is installed (ollama.ai) and pull the necessary models:

```bash
ollama pull llama3.2
ollama pull nomic-embed-text
```

### 4. Configure Environment Variables

Create a `.env` file in the root directory. Here are the core variables you need:

```bash
# --- CORE APP ---
AS BEFORE!

# --- TOOTHFAIRYAI AGENT INTEGRATION (NEW, ALL FIELDS REQUIRED) ---
TOOTHFAIRYAI_API_KEY=your_api_key_here
TOOTHFAIRYAI_WORKSPACE_ID=your_workspace_uuid_here
TOOTHFAIRYAI_API_URL=https://api.toothfairyai.com
NGROK_URL=https://your-url.ngrok-free.app
```

### 5. Run the Application

```bash
python app_dev.py
```

Open http://localhost:5001 in your browser. Enter your Azure Client ID, Client Secret, and Tenant ID in the UI to connect your SharePoint account.

## Azure App Registration Setup

To connect to SharePoint, you must register an app in the Azure Portal:

1. Go to Azure Active Directory -> App registrations -> New registration.
2. Set Supported account types to Multi-tenant.
3. Set Redirect URI to http://localhost:5001/sharepoint/callback (update for production).
4. Save your Application (client) ID and Directory (tenant) ID.
5. Under Certificates & secrets, create a new client secret and copy the Value.
6. Under API permissions, add Delegated permissions for Microsoft Graph: User.Read, Sites.Read.All, Files.Read.All, offline_access.

## Production Deployment (AWS ECS)

For production, the app is designed to run on AWS ECS Fargate.

- **Database:** Use a managed AWS DynamoDB table and an RDS PostgreSQL instance.
- **Environment:** Run `python app.py` instead of `app_dev.py`.
- **Security:** Assign an IAM Task Role to the ECS container with strict DynamoDB permissions. Store the `FLASK_SECRET_KEY` in AWS Secrets Manager.
- **Routing:** Place the ECS service behind an Application Load Balancer (ALB) configured with HTTPS.
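
The Azure registration values feed the OAuth 2.0 authorization code flow. As an illustration only, the sketch below builds the Microsoft identity platform authorization URL for the delegated permissions listed in the README. The `build_authorize_url` helper is hypothetical, and the connector in this commit may construct the URL differently.

```python
from urllib.parse import urlencode

# Delegated Microsoft Graph permissions listed in the README.
SCOPES = ["User.Read", "Sites.Read.All", "Files.Read.All", "offline_access"]


def build_authorize_url(client_id: str, redirect_uri: str, state: str,
                        tenant: str = "common") -> str:
    """Hypothetical helper: URL the user visits to grant consent."""
    query = urlencode({
        "client_id": client_id,
        "response_type": "code",       # authorization code flow
        "redirect_uri": redirect_uri,  # must match the app registration
        "response_mode": "query",
        "scope": " ".join(SCOPES),
        "state": state,                # random value checked in the callback
    })
    return f"https://login.microsoftonline.com/{tenant}/oauth2/v2.0/authorize?{query}"
```

The `state` value should be generated per request (for example with `secrets.token_urlsafe`) and compared in the callback to guard against cross-site request forgery.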

1014
app.py Normal file

File diff suppressed because it is too large

1094
app_dev.py Normal file

File diff suppressed because it is too large

308
background_indexer.py Normal file

@@ -0,0 +1,308 @@
"""
Background Document Indexer
Indexes SharePoint files in the background without blocking the UI.
Uses threading for simple deployment (no Celery/Redis required).
"""
import threading
import time
from typing import Dict, List, Optional
from datetime import datetime
import logging
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class IndexingJob:
"""Represents a background indexing job."""
def __init__(self, job_id: str, site_id: str, site_name: str, connection_id: str, user_id: str):
self.job_id = job_id
self.site_id = site_id
self.site_name = site_name
self.connection_id = connection_id
self.user_id = user_id
self.status = "pending" # pending, running, completed, failed
self.progress = 0 # 0-100
self.total_files = 0
self.processed_files = 0
self.successful_files = 0
self.failed_files = 0
self.started_at = None
self.completed_at = None
self.error = None
self.current_file = None
class BackgroundIndexer:
"""
Background indexer for SharePoint documents.
Runs indexing jobs in separate threads to avoid blocking the main application.
"""
def __init__(self):
self.jobs: Dict[str, IndexingJob] = {}
self.active_threads: Dict[str, threading.Thread] = {}
self.lock = threading.Lock()
    def start_indexing(
        self,
        job_id: str,
        site_id: str,
        site_name: str,
        connection_id: str,
        user_id: str,
        connector,
        vector_store,
        document_parser,
        path: str = "",
        tags: Optional[List[str]] = None
    ) -> IndexingJob:
        """
        Start a background indexing job for a SharePoint site.

        Args:
            job_id: Unique job identifier
            site_id: SharePoint site ID
            site_name: Site display name
            connection_id: SharePoint connection ID
            user_id: User ID
            connector: SharePoint connector instance
            vector_store: Vector store instance
            document_parser: Document parser instance
            path: Optional path to start indexing from
            tags: Optional tags applied to every document indexed by this job

        Returns:
            IndexingJob instance
        """
with self.lock:
# Create job
job = IndexingJob(job_id, site_id, site_name, connection_id, user_id)
self.jobs[job_id] = job
# Start thread
thread = threading.Thread(
target=self._index_site,
args=(job, connector, vector_store, document_parser, path, tags or []),
daemon=True
)
self.active_threads[job_id] = thread
thread.start()
logger.info(f"Started indexing job {job_id} for site {site_name}")
return job
def get_job_status(self, job_id: str) -> Optional[Dict]:
"""Get the status of an indexing job."""
with self.lock:
job = self.jobs.get(job_id)
if not job:
return None
return {
"job_id": job.job_id,
"site_id": job.site_id,
"site_name": job.site_name,
"status": job.status,
"progress": job.progress,
"total_files": job.total_files,
"processed_files": job.processed_files,
"successful_files": job.successful_files,
"failed_files": job.failed_files,
"current_file": job.current_file,
"started_at": job.started_at.isoformat() if job.started_at else None,
"completed_at": job.completed_at.isoformat() if job.completed_at else None,
"error": job.error
}
def cancel_job(self, job_id: str) -> bool:
"""Cancel a running indexing job."""
with self.lock:
job = self.jobs.get(job_id)
if job and job.status == "running":
job.status = "cancelled"
job.completed_at = datetime.utcnow()
logger.info(f"Cancelled indexing job {job_id}")
return True
return False
def _index_site(self, job, connector, vector_store, document_parser, path="", tags=None):
"""
Index all files in a SharePoint site (runs in background thread).
This method:
1. Recursively lists all files in the site
2. Downloads and parses each file
3. Generates embeddings
4. Stores in vector store with specified tags
"""
from saas_connector_dynamodb import SecureSharePointClient
if tags is None:
tags = []
try:
job.status = "running"
job.started_at = datetime.utcnow()
# Create SharePoint client
client = SecureSharePointClient(connector, job.connection_id, job.user_id)
# First, count total files
logger.info(f"[Job {job.job_id}] Counting files in {job.site_name}...")
all_files = self._list_all_files_recursive(client, job, job.site_id, path)
if job.status == "cancelled":
return
job.total_files = len(all_files)
logger.info(f"[Job {job.job_id}] Found {job.total_files} files to index")
# Process each file
for file_info in all_files:
if job.status == "cancelled":
logger.info(f"[Job {job.job_id}] Job cancelled by user")
break
try:
self._process_file(
job,
client,
vector_store,
document_parser,
file_info,
tags
)
job.successful_files += 1
except Exception as e:
logger.error(f"[Job {job.job_id}] Failed to process {file_info['path']}: {e}")
job.failed_files += 1
job.processed_files += 1
job.progress = int((job.processed_files / job.total_files) * 100) if job.total_files > 0 else 0
# Mark as completed
if job.status != "cancelled":
job.status = "completed"
job.completed_at = datetime.utcnow()
logger.info(
f"[Job {job.job_id}] Completed: {job.successful_files} successful, "
f"{job.failed_files} failed out of {job.total_files} total"
)
except Exception as e:
logger.error(f"[Job {job.job_id}] Job failed with error: {e}")
job.status = "failed"
job.error = str(e)
job.completed_at = datetime.utcnow()
def _list_all_files_recursive(self, client, job, site_id, path=""):
"""Recursively list all files in a site."""
files_to_process = []
try:
items = client.list_files(site_id, path)
for item in items:
if job.status == "cancelled":
break
# If it's a folder, recurse
if 'folder' in item:
folder_name = item['name']
new_path = f"{path}/{folder_name}" if path else folder_name
files_to_process.extend(
self._list_all_files_recursive(client, job, site_id, new_path)
)
else:
# It's a file
file_path = f"{path}/{item['name']}" if path else item['name']
files_to_process.append({
'name': item['name'],
'path': file_path,
'size': item.get('size', 0)
})
except Exception as e:
logger.error(f"[Job {job.job_id}] Error listing files in {path}: {e}")
return files_to_process
def _process_file(self, job, client, vector_store, document_parser, file_info, tags=None):
"""Process a single file: download, parse, embed, store."""
if not vector_store:
logger.warning(f"[Job {job.job_id}] Vector store not available, skipping {file_info['name']}")
return
if tags is None:
tags = []
filename = file_info['name']
file_path = file_info['path']
job.current_file = filename
logger.info(f"[Job {job.job_id}] Processing {filename}...")
# Check if file can be parsed
if not document_parser.can_parse(filename):
logger.info(f"[Job {job.job_id}] Skipping unsupported file type: {filename}")
return
# Download file
binary_content = client.read_file(job.site_id, file_path, as_text=False)
        # Parse content
        try:
            content = document_parser.parse(binary_content, filename)
        except Exception as parse_err:
            logger.warning(f"[Job {job.job_id}] Failed to parse {filename}: {parse_err}")
            # Try fallback to text
            try:
                content = binary_content.decode('utf-8', errors='ignore')
            except Exception:
                # A bare "except:" would also swallow KeyboardInterrupt and SystemExit
                logger.error(f"[Job {job.job_id}] Could not decode {filename}, skipping")
                return
# Skip error messages
if content.startswith("[") and content.endswith("]"):
logger.info(f"[Job {job.job_id}] Skipping error content for {filename}")
return
# Skip very small files (likely empty)
if len(content.strip()) < 50:
logger.info(f"[Job {job.job_id}] Skipping small file {filename} ({len(content)} chars)")
return
# Add to vector store with specified tags
try:
document_id = vector_store.add_document(
user_id=job.user_id,
site_id=job.site_id,
file_path=file_path,
filename=filename,
content=content,
tags=tags, # Apply tags from indexing request
chunk_size=1000,
chunk_overlap=200
)
logger.info(f"[Job {job.job_id}] Indexed {filename} with document_id {document_id} and tags {tags}")
except Exception as e:
logger.error(f"[Job {job.job_id}] Failed to add {filename} to vector store: {e}")
raise
# Global indexer instance
_indexer = None
def get_indexer() -> BackgroundIndexer:
"""Get the global indexer instance."""
global _indexer
if _indexer is None:
_indexer = BackgroundIndexer()
return _indexer
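
`_process_file` passes `chunk_size=1000` and `chunk_overlap=200` to the vector store, whose implementation is not part of this file. The sketch below shows one common way such parameters are interpreted: fixed-size character windows whose starts are `chunk_size - chunk_overlap` apart. The `chunk_text` function is a hypothetical illustration, and the actual `add_document` may split text differently (for example on sentence or token boundaries).

```python
from typing import List


def chunk_text(text: str, chunk_size: int = 1000, chunk_overlap: int = 200) -> List[str]:
    """Split text into windows of chunk_size chars that overlap by chunk_overlap."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")
    step = chunk_size - chunk_overlap  # distance between chunk starts
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reached the end of the text
    return chunks
```

With the defaults, a 2,500-character document yields windows starting at offsets 0, 800, and 1600, so text near a chunk boundary appears in two chunks and stays retrievable together with its context.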

196
clear_data.py Normal file

@@ -0,0 +1,196 @@
#!/usr/bin/env python3
"""
Clear all data from the SharePoint connector application.
This script will:
1. Drop all vector store tables (PostgreSQL)
2. Delete all DynamoDB tables (connections, tokens, etc.)
3. Clear stored credentials
"""
import os
import sys
import boto3
from storage.credentials_storage import get_credentials_storage
def clear_postgresql_vector_store():
"""Drop all vector store tables from PostgreSQL."""
try:
import psycopg2
# Get connection params from environment
db_host = os.getenv("POSTGRES_HOST", "localhost")
db_port = os.getenv("POSTGRES_PORT", "5432")
db_name = os.getenv("POSTGRES_DB", "sharepoint_vectors")
db_user = os.getenv("POSTGRES_USER", "postgres")
db_password = os.getenv("POSTGRES_PASSWORD", "postgres")
print(f"📊 Connecting to PostgreSQL at {db_host}:{db_port}/{db_name}...")
conn = psycopg2.connect(
host=db_host,
port=db_port,
database=db_name,
user=db_user,
password=db_password
)
cursor = conn.cursor()
# Get table prefix
table_prefix = os.getenv("TABLE_PREFIX", "dev_")
# Drop tables in correct order (child tables first)
tables = [
f"{table_prefix}document_tags",
f"{table_prefix}document_chunks",
f"{table_prefix}documents"
]
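        # Quote table names as SQL identifiers instead of interpolating them
        from psycopg2 import sql
        for table in tables:
            try:
                cursor.execute(
                    sql.SQL("DROP TABLE IF EXISTS {} CASCADE").format(sql.Identifier(table))
                )
                print(f" ✅ Dropped table: {table}")
            except Exception as e:
                print(f" ⚠️ Could not drop {table}: {e}")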
conn.commit()
cursor.close()
conn.close()
print("✅ PostgreSQL vector store cleared")
return True
except ImportError:
print("⚠️ psycopg2 not installed, skipping PostgreSQL cleanup")
return False
except Exception as e:
print(f"❌ Error clearing PostgreSQL: {e}")
return False
def clear_dynamodb_tables():
"""Delete all DynamoDB tables."""
try:
# Determine if using local or AWS DynamoDB
endpoint_url = os.getenv("DYNAMODB_ENDPOINT", "http://localhost:8000")
region = os.getenv("AWS_REGION", "ap-southeast-2")
table_prefix = os.getenv("TABLE_PREFIX", "dev_")
if "localhost" in endpoint_url or "127.0.0.1" in endpoint_url:
print(f"📊 Connecting to LOCAL DynamoDB at {endpoint_url}...")
else:
print(f"📊 Connecting to AWS DynamoDB in {region}...")
dynamodb = boto3.client(
'dynamodb',
region_name=region,
endpoint_url=endpoint_url if "localhost" in endpoint_url or "127.0.0.1" in endpoint_url else None
)
# List all tables
response = dynamodb.list_tables()
tables = response.get('TableNames', [])
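        # Filter tables with our prefix. An empty prefix matches every table
        # name, so refuse to continue rather than delete unrelated tables.
        if not table_prefix:
            print(" ⚠️ TABLE_PREFIX is empty; refusing to delete all tables")
            return False
        our_tables = [t for t in tables if t.startswith(table_prefix)]
        if not our_tables:
            print(f" No tables found with prefix '{table_prefix}'")
            return True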
print(f" Found {len(our_tables)} tables to delete:")
for table in our_tables:
print(f" - {table}")
# Delete each table
for table in our_tables:
try:
dynamodb.delete_table(TableName=table)
print(f" ✅ Deleted table: {table}")
except Exception as e:
print(f" ⚠️ Could not delete {table}: {e}")
print("✅ DynamoDB tables cleared")
return True
except Exception as e:
print(f"❌ Error clearing DynamoDB: {e}")
return False
def clear_credentials():
"""Clear all stored credentials."""
try:
print("📊 Clearing stored credentials...")
storage = get_credentials_storage()
# Get the storage file path
if hasattr(storage, 'file_path'):
file_path = storage.file_path
if os.path.exists(file_path):
os.remove(file_path)
print(f" ✅ Deleted credentials file: {file_path}")
else:
print(f" No credentials file found at: {file_path}")
else:
print(" Credentials storage does not use file system")
print("✅ Credentials cleared")
return True
except Exception as e:
print(f"❌ Error clearing credentials: {e}")
return False
def main():
"""Main function to clear all data."""
print("=" * 60)
print("🧹 CLEARING ALL SHAREPOINT CONNECTOR DATA")
print("=" * 60)
print()
# Check for --force flag
force = '--force' in sys.argv or '-f' in sys.argv
if not force:
try:
response = input("⚠️ This will DELETE ALL data. Are you sure? (yes/no): ")
if response.lower() not in ['yes', 'y']:
print("❌ Aborted")
return
except EOFError:
print("\n❌ No input received. Use --force to skip confirmation.")
return
print()
print("Starting cleanup...")
print()
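    # Run each step and record whether it succeeded
    results = []
    # Clear PostgreSQL vector store
    print("1️⃣ Clearing PostgreSQL Vector Store...")
    results.append(clear_postgresql_vector_store())
    print()
    # Clear DynamoDB tables
    print("2️⃣ Clearing DynamoDB Tables...")
    results.append(clear_dynamodb_tables())
    print()
    # Clear credentials
    print("3️⃣ Clearing Stored Credentials...")
    results.append(clear_credentials())
    print()
    print("=" * 60)
    if all(results):
        print("✅ ALL DATA CLEARED SUCCESSFULLY")
    else:
        print("⚠️ CLEANUP FINISHED WITH ERRORS OR SKIPPED STEPS (see messages above)")
    print("=" * 60)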
print()
print("You can now start fresh by running the application:")
print(" python app_dev.py")
print()
if __name__ == "__main__":
main()
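
`clear_dynamodb_tables` selects tables with `startswith(table_prefix)`. Because every string starts with the empty string, a `TABLE_PREFIX` that is set but empty would select every table in the account and region. A minimal sketch of a guarded filter follows; the `select_tables_to_delete` helper is hypothetical and only illustrates the check.

```python
from typing import Iterable, List


def select_tables_to_delete(table_names: Iterable[str], prefix: str) -> List[str]:
    """Hypothetical helper: return only the tables owned by this deployment."""
    if not prefix:
        # Every name starts with "", so an empty prefix would select
        # every table returned by list_tables().
        raise ValueError("TABLE_PREFIX must be non-empty for a bulk delete")
    return [name for name in table_names if name.startswith(prefix)]
```

Keeping the selection in a pure function also makes it easy to unit test without a DynamoDB endpoint.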

54
deploy.sh Normal file

@@ -0,0 +1,54 @@
#!/bin/bash
# SharePoint Connector - AWS ECS Deployment Script
# Exit on errors, on unset variables, and on failures inside pipelines
set -euo pipefail
# Configuration
AWS_REGION="ap-southeast-2"
AWS_ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text)
ECR_REPOSITORY="sharepoint-connector"
ECS_CLUSTER="your-cluster-name"
ECS_SERVICE="sharepoint-connector-service"
TASK_FAMILY="sharepoint-connector"
echo "🚀 Deploying SharePoint Connector to AWS ECS"
echo "📍 Region: $AWS_REGION"
echo "🏷️ Account: $AWS_ACCOUNT_ID"
echo ""
# Step 1: Build Docker image
echo "🔨 Building Docker image..."
docker build -t $ECR_REPOSITORY:latest .
# Step 2: Login to ECR
echo "🔐 Logging in to ECR..."
aws ecr get-login-password --region $AWS_REGION | docker login --username AWS --password-stdin $AWS_ACCOUNT_ID.dkr.ecr.$AWS_REGION.amazonaws.com
# Step 3: Create ECR repository if it doesn't exist
echo "📦 Ensuring ECR repository exists..."
aws ecr describe-repositories --repository-names $ECR_REPOSITORY --region $AWS_REGION 2>/dev/null || \
aws ecr create-repository --repository-name $ECR_REPOSITORY --region $AWS_REGION
# Step 4: Tag and push image
echo "📤 Pushing image to ECR..."
docker tag $ECR_REPOSITORY:latest $AWS_ACCOUNT_ID.dkr.ecr.$AWS_REGION.amazonaws.com/$ECR_REPOSITORY:latest
docker push $AWS_ACCOUNT_ID.dkr.ecr.$AWS_REGION.amazonaws.com/$ECR_REPOSITORY:latest
# Step 5: Update task definition
echo "📝 Updating ECS task definition..."
TASK_DEFINITION=$(sed "s/YOUR_ACCOUNT_ID/$AWS_ACCOUNT_ID/g" task-definition.json)
aws ecs register-task-definition --cli-input-json "$TASK_DEFINITION" --region $AWS_REGION
# Step 6: Update ECS service
echo "🔄 Updating ECS service..."
aws ecs update-service \
--cluster $ECS_CLUSTER \
--service $ECS_SERVICE \
--task-definition $TASK_FAMILY \
--force-new-deployment \
--region $AWS_REGION
echo ""
echo "✅ Deployment complete!"
echo "📊 Check status: aws ecs describe-services --cluster $ECS_CLUSTER --services $ECS_SERVICE --region $AWS_REGION"

30
docker-compose.yml Normal file

@@ -0,0 +1,30 @@
version: '3.8'
services:
  # DynamoDB Local for development
  dynamodb-local:
    image: amazon/dynamodb-local
    container_name: dynamodb-local
    ports:
      - "8000:8000"
    command: "-jar DynamoDBLocal.jar -sharedDb -inMemory"
  # Development app
  app-dev:
    build:
      context: .
      dockerfile: Dockerfile
    container_name: sharepoint-connector-dev
    ports:
      - "5000:5000"
    environment:
      - AWS_REGION=ap-southeast-2
      - AWS_ACCESS_KEY_ID=fakeAccessKeyId
      - AWS_SECRET_ACCESS_KEY=fakeSecretAccessKey
    volumes:
      - ./app_dev.py:/app/app.py
      - ./templates:/app/templates
      - ./static:/app/static
    depends_on:
      - dynamodb-local
    command: python app.py

245
document_parser.py Normal file

@@ -0,0 +1,245 @@
"""
Document Parser - Extract text from various file formats
Supports: PDF, Word (.docx), Excel (.xlsx), PowerPoint (.pptx),
CSV, text files, and more
"""
import io
import mimetypes
from typing import Optional
class DocumentParser:
"""Parse different document types and extract text content for LLM processing."""
def __init__(self):
"""Initialize document parser."""
self.supported_extensions = {
# Text formats
'.txt', '.md', '.csv', '.json', '.xml', '.html', '.css', '.js',
'.py', '.java', '.cpp', '.c', '.h', '.sh', '.yaml', '.yml',
# Microsoft Office
'.docx', '.xlsx', '.pptx',
# PDF
'.pdf',
# Other
'.rtf', '.log'
}
def can_parse(self, filename: str) -> bool:
"""Check if file can be parsed."""
ext = self._get_extension(filename)
return ext in self.supported_extensions
def parse(self, content: bytes, filename: str) -> str:
"""
Parse document and extract text content.
Args:
content: File content as bytes
filename: Original filename (used to determine file type)
Returns:
Extracted text content
Raises:
ValueError: If file type is not supported
"""
ext = self._get_extension(filename)
if ext not in self.supported_extensions:
raise ValueError(f"Unsupported file type: {ext}")
# Text files - direct decode
if ext in {'.txt', '.md', '.csv', '.json', '.xml', '.html', '.css',
'.js', '.py', '.java', '.cpp', '.c', '.h', '.sh', '.yaml',
'.yml', '.log', '.rtf'}:
return self._parse_text(content)
# PDF
elif ext == '.pdf':
return self._parse_pdf(content)
# Microsoft Word
elif ext == '.docx':
return self._parse_docx(content)
# Microsoft Excel
elif ext == '.xlsx':
return self._parse_xlsx(content)
# Microsoft PowerPoint
elif ext == '.pptx':
return self._parse_pptx(content)
else:
raise ValueError(f"Parser not implemented for: {ext}")
def _get_extension(self, filename: str) -> str:
"""Get file extension in lowercase."""
return '.' + filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''
    def _parse_text(self, content: bytes) -> str:
        """Parse plain text files."""
        # Try strict encodings first. latin-1 accepts every byte sequence,
        # so it must come last or the encodings after it are never tried.
        for encoding in ['utf-8', 'cp1252', 'latin-1']:
            try:
                return content.decode(encoding)
            except (UnicodeDecodeError, AttributeError):
                continue
        # Fallback: decode with errors='ignore'
        return content.decode('utf-8', errors='ignore')
def _parse_pdf(self, content: bytes) -> str:
"""Parse PDF files."""
try:
import PyPDF2
pdf_file = io.BytesIO(content)
reader = PyPDF2.PdfReader(pdf_file)
text_parts = []
for page_num, page in enumerate(reader.pages):
text = page.extract_text()
if text.strip():
text_parts.append(f"--- Page {page_num + 1} ---\n{text}")
return '\n\n'.join(text_parts) if text_parts else "(empty PDF)"
except ImportError:
return "[PDF parsing requires PyPDF2: pip install PyPDF2]"
except Exception as e:
return f"[Error parsing PDF: {str(e)}]"
def _parse_docx(self, content: bytes) -> str:
"""Parse Word documents (.docx)."""
try:
import docx
doc_file = io.BytesIO(content)
doc = docx.Document(doc_file)
text_parts = []
# Extract paragraphs
for para in doc.paragraphs:
if para.text.strip():
text_parts.append(para.text)
# Extract tables
for table in doc.tables:
table_text = []
for row in table.rows:
row_text = [cell.text.strip() for cell in row.cells]
table_text.append(' | '.join(row_text))
if table_text:
text_parts.append('\n' + '\n'.join(table_text))
return '\n\n'.join(text_parts) if text_parts else "(empty document)"
except ImportError:
return "[Word parsing requires python-docx: pip install python-docx]"
except Exception as e:
return f"[Error parsing Word document: {str(e)}]"
def _parse_xlsx(self, content: bytes) -> str:
"""Parse Excel spreadsheets (.xlsx)."""
try:
import openpyxl
excel_file = io.BytesIO(content)
workbook = openpyxl.load_workbook(excel_file, data_only=True)
text_parts = []
for sheet_name in workbook.sheetnames:
sheet = workbook[sheet_name]
sheet_text = [f"=== Sheet: {sheet_name} ==="]
# Get data rows
for row in sheet.iter_rows(values_only=True):
# Skip empty rows
if any(cell is not None for cell in row):
row_text = ' | '.join(str(cell) if cell is not None else '' for cell in row)
sheet_text.append(row_text)
if len(sheet_text) > 1: # Has content beyond header
text_parts.append('\n'.join(sheet_text))
return '\n\n'.join(text_parts) if text_parts else "(empty spreadsheet)"
except ImportError:
return "[Excel parsing requires openpyxl: pip install openpyxl]"
except Exception as e:
return f"[Error parsing Excel file: {str(e)}]"
def _parse_pptx(self, content: bytes) -> str:
"""Parse PowerPoint presentations (.pptx)."""
try:
from pptx import Presentation
ppt_file = io.BytesIO(content)
prs = Presentation(ppt_file)
text_parts = []
for slide_num, slide in enumerate(prs.slides, start=1):
slide_text = [f"=== Slide {slide_num} ==="]
for shape in slide.shapes:
if hasattr(shape, "text") and shape.text.strip():
slide_text.append(shape.text)
if len(slide_text) > 1: # Has content beyond header
text_parts.append('\n'.join(slide_text))
return '\n\n'.join(text_parts) if text_parts else "(empty presentation)"
except ImportError:
return "[PowerPoint parsing requires python-pptx: pip install python-pptx]"
except Exception as e:
return f"[Error parsing PowerPoint file: {str(e)}]"
def get_file_info(filename: str, file_size: int) -> dict:
"""Get human-readable file information."""
ext = '.' + filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''
# File type categories
type_categories = {
'document': {'.pdf', '.docx', '.doc', '.txt', '.rtf', '.odt'},
'spreadsheet': {'.xlsx', '.xls', '.csv', '.ods'},
'presentation': {'.pptx', '.ppt', '.odp'},
'image': {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.svg', '.webp'},
'video': {'.mp4', '.avi', '.mov', '.wmv', '.flv', '.mkv', '.webm'},
'audio': {'.mp3', '.wav', '.flac', '.aac', '.ogg', '.wma'},
'archive': {'.zip', '.rar', '.7z', '.tar', '.gz'},
'code': {'.py', '.js', '.java', '.cpp', '.c', '.h', '.cs', '.php', '.rb'},
'web': {'.html', '.css', '.xml', '.json', '.yaml', '.yml'},
}
category = 'file'
for cat, extensions in type_categories.items():
if ext in extensions:
category = cat
break
# Format file size
if file_size < 1024:
size_str = f"{file_size} B"
elif file_size < 1024 * 1024:
size_str = f"{file_size / 1024:.1f} KB"
elif file_size < 1024 * 1024 * 1024:
size_str = f"{file_size / (1024 * 1024):.1f} MB"
else:
size_str = f"{file_size / (1024 * 1024 * 1024):.1f} GB"
return {
'extension': ext,
'category': category,
'size_formatted': size_str,
'size_bytes': file_size
}
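
When trying encodings in turn, as `_parse_text` does, the order matters: latin-1 maps every possible byte, so any candidate listed after it is never reached. The standalone sketch below (the function name is hypothetical) puts the strict encodings first and keeps latin-1 as the catch-all.

```python
from typing import Sequence


def decode_with_fallback(
    data: bytes,
    encodings: Sequence[str] = ("utf-8", "cp1252", "latin-1"),
) -> str:
    """Decode bytes by trying each encoding in order, strictest first."""
    for encoding in encodings:
        try:
            return data.decode(encoding)
        except UnicodeDecodeError:
            continue  # this encoding cannot represent the bytes, try the next
    # Only reached if the caller passed a list without a catch-all encoding.
    return data.decode("utf-8", errors="ignore")
```

Windows-1252 "smart quotes" (bytes 0x93 and 0x94) are a typical case: UTF-8 rejects them, cp1252 decodes them to curly quotes, and latin-1 would have produced invisible control characters instead.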

23
iam-task-role-policy.json Normal file

@@ -0,0 +1,23 @@
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "dynamodb:CreateTable",
        "dynamodb:DescribeTable",
        "dynamodb:GetItem",
        "dynamodb:PutItem",
        "dynamodb:UpdateItem",
        "dynamodb:Query",
        "dynamodb:Scan",
        "dynamodb:UpdateTimeToLive",
        "dynamodb:DescribeTimeToLive"
      ],
      "Resource": [
        "arn:aws:dynamodb:ap-southeast-2:*:table/prod_*",
        "arn:aws:dynamodb:ap-southeast-2:*:table/prod_*/index/*"
      ]
    }
  ]
}

324
llm_client.py Normal file

@@ -0,0 +1,324 @@
"""
LLM Client Abstraction Layer
Supports multiple LLM providers with easy swapping
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Optional
import requests
import json
import os
class LLMClient(ABC):
"""Abstract base class for LLM clients"""
@abstractmethod
def chat(self, messages: List[Dict[str, str]], context: Optional[str] = None) -> str:
"""
Send messages to LLM and get response
Args:
messages: List of {"role": "user"/"assistant", "content": "text"}
context: Optional document context to include
Returns:
LLM response text
"""
pass
@abstractmethod
def is_available(self) -> bool:
"""Check if LLM service is available"""
pass
class OllamaClient(LLMClient):
"""Ollama LLM client (local deployment)"""
def __init__(self,
base_url: str = "http://localhost:11434",
model: str = "llama3.2",
timeout: int = 120):
"""
Initialize Ollama client
Args:
base_url: Ollama server URL
model: Model name (e.g., llama3.2, mistral, codellama)
timeout: Request timeout in seconds
"""
self.base_url = base_url.rstrip('/')
self.model = model
self.timeout = timeout
def chat(self, messages: List[Dict[str, str]], context: Optional[str] = None) -> str:
"""Send chat request to Ollama"""
# Build prompt with context if provided
formatted_messages = []
if context:
# Add system message with document context
formatted_messages.append({
"role": "system",
"content": f"You are a helpful assistant. Use the following document content to answer questions:\n\n{context}"
})
# Add conversation history
formatted_messages.extend(messages)
# Call Ollama API
try:
response = requests.post(
f"{self.base_url}/api/chat",
json={
"model": self.model,
"messages": formatted_messages,
"stream": False
},
timeout=self.timeout
)
response.raise_for_status()
result = response.json()
return result["message"]["content"]
        except requests.exceptions.RequestException as e:
            # Chain the original exception so the traceback keeps the root cause
            raise Exception(f"Ollama API error: {str(e)}") from e
def chat_stream(self, messages: List[Dict[str, str]], context: Optional[str] = None):
"""Send chat request to Ollama with streaming response"""
# Build prompt with context if provided
formatted_messages = []
if context:
# Add system message with document context
formatted_messages.append({
"role": "system",
"content": f"You are a helpful assistant. Use the following document content to answer questions:\n\n{context}"
})
# Add conversation history
formatted_messages.extend(messages)
# Call Ollama API with streaming
try:
response = requests.post(
f"{self.base_url}/api/chat",
json={
"model": self.model,
"messages": formatted_messages,
"stream": True
},
timeout=self.timeout,
stream=True
)
response.raise_for_status()
# Yield chunks as they arrive
for line in response.iter_lines():
if line:
try:
chunk = json.loads(line)
if "message" in chunk and "content" in chunk["message"]:
yield chunk["message"]["content"]
if chunk.get("done", False):
break
except json.JSONDecodeError:
continue
except requests.exceptions.RequestException as e:
raise Exception(f"Ollama API error: {e}") from e
def is_available(self) -> bool:
"""Check if Ollama is running"""
try:
response = requests.get(f"{self.base_url}/api/tags", timeout=5)
return response.status_code == 200
except requests.exceptions.RequestException:
return False
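The streaming loop in `chat_stream` relies on the server sending one JSON object per line. The parsing step can be sketched offline as below. This is a minimal illustration, not part of the commit: the helper name `iter_stream_content` and the sample payload are hypothetical, and unlike the original loop it also skips empty content strings.

```python
import json
from typing import Iterable, Iterator


def iter_stream_content(lines: Iterable[bytes]) -> Iterator[str]:
    """Yield message text from newline-delimited JSON chunks.

    Hypothetical helper mirroring the loop in chat_stream. It accepts any
    iterable of raw lines, so it can be exercised without a running server.
    """
    for line in lines:
        if not line:
            continue  # skip blank keep-alive lines
        try:
            chunk = json.loads(line)
        except json.JSONDecodeError:
            continue  # ignore partial or malformed lines
        content = chunk.get("message", {}).get("content")
        if content:
            yield content
        if chunk.get("done", False):
            break  # the final chunk carries done=true


# Sample payload shaped like the chunks chat_stream expects (illustrative only)
sample = [
    b'{"message": {"role": "assistant", "content": "Hel"}, "done": false}',
    b"",
    b"not json",
    b'{"message": {"role": "assistant", "content": "lo"}, "done": false}',
    b'{"message": {"role": "assistant", "content": ""}, "done": true}',
    b'{"message": {"content": "ignored"}}',
]
streamed_text = "".join(iter_stream_content(sample))
```

Keeping the parser separate from the HTTP call is one way to unit-test malformed or truncated chunks without mocking `requests`.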
class OpenAIClient(LLMClient):
"""OpenAI API client (for easy swapping)"""
def __init__(self,
api_key: str,
model: str = "gpt-4",
timeout: int = 120):
"""
Initialize OpenAI client
Args:
api_key: OpenAI API key
model: Model name (e.g., gpt-4, gpt-3.5-turbo)
timeout: Request timeout in seconds
"""
self.api_key = api_key
self.model = model
self.timeout = timeout
def chat(self, messages: List[Dict[str, str]], context: Optional[str] = None) -> str:
"""Send chat request to OpenAI"""
formatted_messages = []
if context:
formatted_messages.append({
"role": "system",
"content": f"You are a helpful assistant. Use the following document content to answer questions:\n\n{context}"
})
formatted_messages.extend(messages)
try:
response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
},
json={
"model": self.model,
"messages": formatted_messages
},
timeout=self.timeout
)
response.raise_for_status()
result = response.json()
return result["choices"][0]["message"]["content"]
except requests.exceptions.RequestException as e:
raise Exception(f"OpenAI API error: {e}") from e
def is_available(self) -> bool:
"""Check if API key is valid"""
try:
response = requests.get(
"https://api.openai.com/v1/models",
headers={"Authorization": f"Bearer {self.api_key}"},
timeout=5
)
return response.status_code == 200
except requests.exceptions.RequestException:
return False
class AnthropicClient(LLMClient):
"""Anthropic Claude API client (for easy swapping)"""
def __init__(self,
api_key: str,
model: str = "claude-3-5-sonnet-20241022",
timeout: int = 120):
"""
Initialize Anthropic client
Args:
api_key: Anthropic API key
model: Model name
timeout: Request timeout in seconds
"""
self.api_key = api_key
self.model = model
self.timeout = timeout
def chat(self, messages: List[Dict[str, str]], context: Optional[str] = None) -> str:
"""Send chat request to Anthropic"""
formatted_messages = []
system_prompt = None
if context:
system_prompt = f"You are a helpful assistant. Use the following document content to answer questions:\n\n{context}"
formatted_messages.extend(messages)
try:
payload = {
"model": self.model,
"messages": formatted_messages,
"max_tokens": 4096
}
if system_prompt:
payload["system"] = system_prompt
response = requests.post(
"https://api.anthropic.com/v1/messages",
headers={
"x-api-key": self.api_key,
"anthropic-version": "2023-06-01",
"Content-Type": "application/json"
},
json=payload,
timeout=self.timeout
)
response.raise_for_status()
result = response.json()
return result["content"][0]["text"]
except requests.exceptions.RequestException as e:
raise Exception(f"Anthropic API error: {e}") from e
def is_available(self) -> bool:
"""Check that an API key is configured (the key is not validated)"""
# No request is sent to Anthropic here, so a non-empty key is treated as available
return len(self.api_key) > 0
def create_llm_client(provider: str = "ollama", **kwargs) -> LLMClient:
"""
Factory function to create LLM client
Args:
provider: LLM provider name ("ollama", "openai", "anthropic")
**kwargs: Provider-specific configuration
Returns:
LLMClient instance
Example:
# Ollama (default)
client = create_llm_client("ollama", model="llama3.2")
# OpenAI
client = create_llm_client("openai", api_key="sk-...", model="gpt-4")
# Anthropic
client = create_llm_client("anthropic", api_key="sk-ant-...", model="claude-3-5-sonnet-20241022")
"""
if provider.lower() == "ollama":
return OllamaClient(
base_url=kwargs.get("base_url", os.getenv("OLLAMA_URL", "http://localhost:11434")),
model=kwargs.get("model", os.getenv("OLLAMA_MODEL", "llama3.2")),
timeout=kwargs.get("timeout", 120)
)
elif provider.lower() == "openai":
api_key = kwargs.get("api_key", os.getenv("OPENAI_API_KEY"))
if not api_key:
raise ValueError("OpenAI API key required")
return OpenAIClient(
api_key=api_key,
model=kwargs.get("model", "gpt-4"),
timeout=kwargs.get("timeout", 120)
)
elif provider.lower() == "anthropic":
api_key = kwargs.get("api_key", os.getenv("ANTHROPIC_API_KEY"))
if not api_key:
raise ValueError("Anthropic API key required")
return AnthropicClient(
api_key=api_key,
model=kwargs.get("model", "claude-3-5-sonnet-20241022"),
timeout=kwargs.get("timeout", 120)
)
else:
raise ValueError(f"Unknown LLM provider: {provider}")
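The Ollama and OpenAI clients above prepend the document context as a `system` message, while the Anthropic client passes it as a separate `system` field. One possible way to share that logic is sketched here. The names `SYSTEM_TEMPLATE` and `build_messages` are assumptions made for illustration and do not exist in this commit.

```python
from typing import Dict, List, Optional, Tuple

# Same wording as the prompt used by the three clients
SYSTEM_TEMPLATE = (
    "You are a helpful assistant. Use the following document content "
    "to answer questions:\n\n{context}"
)


def build_messages(
    messages: List[Dict[str, str]],
    context: Optional[str] = None,
    system_as_message: bool = True,
) -> Tuple[Optional[str], List[Dict[str, str]]]:
    """Return (system_prompt, messages) for a chat request.

    With system_as_message=True the context becomes the first message
    (Ollama/OpenAI style). Otherwise it is returned separately so the
    caller can place it in a top-level field (Anthropic style).
    """
    system_prompt = SYSTEM_TEMPLATE.format(context=context) if context else None
    formatted = list(messages)  # copy so the caller's history is not mutated
    if system_prompt and system_as_message:
        formatted.insert(0, {"role": "system", "content": system_prompt})
        return None, formatted
    return system_prompt, formatted


history = [{"role": "user", "content": "Summarize the report."}]
_, inline_style = build_messages(history, context="Q3 revenue rose 4%.")
system, separate_style = build_messages(
    history, context="Q3 revenue rose 4%.", system_as_message=False
)
```

Each client's `chat` method could then call the helper and keep only its provider-specific request and response handling.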

23
requirements.txt Normal file
View File

@@ -0,0 +1,23 @@
# SharePoint Connector Requirements
# Core dependencies
boto3>=1.34.0
cryptography>=41.0.0
msal>=1.24.0
requests>=2.31.0
# Web framework
flask>=3.0.0
# Production server
gunicorn>=21.2.0
# Document parsing
PyPDF2>=3.0.0
python-docx>=1.0.0
openpyxl>=3.1.0
python-pptx>=0.6.0
# PostgreSQL with pgvector for embeddings
psycopg2-binary>=2.9.0
pgvector>=0.2.0

918
saas_connector_dynamodb.py Normal file
View File

@@ -0,0 +1,918 @@
"""
Enterprise SaaS SharePoint Connector - DynamoDB Edition
Secure multi-tenant SharePoint connector using AWS DynamoDB for storage.
Implements OAuth 2.0 with encrypted token storage, DynamoDB persistence,
and enterprise security best practices.
"""
import os
import secrets
import hashlib
from typing import Optional, Dict, Any, List
from datetime import datetime, timedelta
from dataclasses import dataclass
import json
from decimal import Decimal
import requests
from msal import ConfidentialClientApplication
from cryptography.fernet import Fernet
import boto3
from boto3.dynamodb.conditions import Key, Attr
from botocore.exceptions import ClientError
from flask import Flask, request, redirect, session, jsonify, url_for
from functools import wraps
# ============================================================================
# DYNAMODB HELPER FUNCTIONS
# ============================================================================
def decimal_to_float(obj):
"""Convert DynamoDB Decimal types to float/int for JSON serialization."""
if isinstance(obj, list):
return [decimal_to_float(i) for i in obj]
elif isinstance(obj, dict):
return {k: decimal_to_float(v) for k, v in obj.items()}
elif isinstance(obj, Decimal):
return int(obj) if obj % 1 == 0 else float(obj)
return obj
def datetime_to_iso(dt: Optional[datetime]) -> Optional[str]:
"""Convert datetime to ISO 8601 string (returns None when dt is None)."""
return dt.isoformat() if dt else None
def iso_to_datetime(iso_string: str) -> Optional[datetime]:
"""Convert ISO 8601 string to datetime."""
try:
return datetime.fromisoformat(iso_string) if iso_string else None
except (ValueError, AttributeError):
return None
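The helpers above work on naive datetimes produced by `datetime.utcnow()`. That call is deprecated since Python 3.12, and `.timestamp()` on a naive value is interpreted in the machine's local time zone. A timezone-aware variant might look like the sketch below. The function names (`utc_now`, `to_iso`, `from_iso`, `to_epoch_seconds`) are hypothetical and not part of this module.

```python
from datetime import datetime, timedelta, timezone
from typing import Optional


def utc_now() -> datetime:
    """Current time as an aware UTC datetime."""
    return datetime.now(timezone.utc)


def to_iso(dt: Optional[datetime]) -> Optional[str]:
    return dt.isoformat() if dt else None


def from_iso(value: Optional[str]) -> Optional[datetime]:
    if not value:
        return None
    try:
        parsed = datetime.fromisoformat(value)
    except ValueError:
        return None
    # Treat stored naive values as UTC so comparisons never mix naive and aware
    return parsed if parsed.tzinfo else parsed.replace(tzinfo=timezone.utc)


def to_epoch_seconds(dt: datetime) -> int:
    """Unix timestamp for an aware datetime, e.g. for a DynamoDB TTL attribute."""
    return int(dt.timestamp())


expires_at = datetime(2026, 1, 1, 12, 0, tzinfo=timezone.utc) + timedelta(minutes=10)
round_tripped = from_iso(to_iso(expires_at))
ttl = to_epoch_seconds(expires_at)
```

Because every value carries an explicit offset, the result no longer depends on the server's local time zone.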
# ============================================================================
# ENCRYPTION SERVICE
# ============================================================================
class TokenEncryption:
"""
Handles encryption/decryption of OAuth tokens.
SECURITY: Never store tokens in plain text!
"""
def __init__(self, encryption_key: Optional[str] = None):
"""
Initialize encryption service.
Args:
encryption_key: Base64-encoded Fernet key. If not provided,
uses ENCRYPTION_KEY env var.
"""
if encryption_key:
self.key = encryption_key.encode()
else:
key_str = os.getenv('ENCRYPTION_KEY')
if not key_str:
raise ValueError("ENCRYPTION_KEY environment variable must be set")
self.key = key_str.encode()
self.cipher = Fernet(self.key)
@staticmethod
def generate_key() -> str:
"""Generate a new encryption key. Store this securely!"""
return Fernet.generate_key().decode()
def encrypt(self, plaintext: str) -> str:
"""Encrypt a string."""
if not plaintext:
return ""
return self.cipher.encrypt(plaintext.encode()).decode()
def decrypt(self, ciphertext: str) -> str:
"""Decrypt a string."""
if not ciphertext:
return ""
return self.cipher.decrypt(ciphertext.encode()).decode()
# ============================================================================
# DYNAMODB DATA MODELS
# ============================================================================
@dataclass
class SharePointConnectionInfo:
"""Connection information for a SharePoint connection."""
id: str
user_id: str
organization_id: Optional[str]
connection_name: Optional[str]
tenant_id: str
microsoft_user_id: str
is_active: bool
created_at: datetime
last_used_at: Optional[datetime]
# ============================================================================
# DYNAMODB CONNECTOR
# ============================================================================
class DynamoDBSharePointConnector:
"""
Enterprise-grade SharePoint connector using DynamoDB for storage.
DynamoDB Tables:
1. sharepoint_connections - Stores encrypted OAuth tokens
2. sharepoint_oauth_states - Temporary OAuth state for CSRF protection
3. sharepoint_audit_logs - Audit trail for compliance
Features:
- Multi-tenant OAuth 2.0
- Encrypted token storage
- Automatic token refresh
- Audit logging
- CSRF protection
- Serverless-friendly (no connection pooling)
"""
def __init__(
self,
client_id: str,
client_secret: str,
redirect_uri: str,
encryption_key: str,
aws_region: str = "us-east-1",
dynamodb_endpoint: Optional[str] = None, # For local development
table_prefix: str = "",
tenant_id: str = "common"
):
"""
Initialize secure DynamoDB connector.
Args:
client_id: Azure AD application ID
client_secret: Azure AD client secret
redirect_uri: OAuth callback URL (must be HTTPS in production!)
encryption_key: Fernet encryption key for token storage
aws_region: AWS region for DynamoDB
dynamodb_endpoint: Custom endpoint (e.g., http://localhost:8000 for local)
table_prefix: Prefix for table names (e.g., "prod_" or "dev_")
tenant_id: Azure AD tenant or "common" for multi-tenant
"""
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.tenant_id = tenant_id
self.table_prefix = table_prefix
# Required scopes for SharePoint access
# Note: offline_access is handled automatically by MSAL and should NOT be included
self.scopes = [
"https://graph.microsoft.com/Sites.Read.All",
"https://graph.microsoft.com/Files.Read.All",
"https://graph.microsoft.com/User.Read"
]
# Initialize encryption
self.encryption = TokenEncryption(encryption_key)
# Initialize DynamoDB
if dynamodb_endpoint:
# Local development
self.dynamodb = boto3.resource('dynamodb',
region_name=aws_region,
endpoint_url=dynamodb_endpoint)
else:
# Production
self.dynamodb = boto3.resource('dynamodb', region_name=aws_region)
# Table names
self.connections_table_name = f"{table_prefix}sharepoint_connections"
self.oauth_states_table_name = f"{table_prefix}sharepoint_oauth_states"
self.audit_logs_table_name = f"{table_prefix}sharepoint_audit_logs"
# Initialize tables
self._ensure_tables_exist()
# Initialize MSAL
self.authority = f"https://login.microsoftonline.com/{tenant_id}"
self.msal_app = ConfidentialClientApplication(
client_id=self.client_id,
client_credential=self.client_secret,
authority=self.authority
)
def _ensure_tables_exist(self):
"""Create DynamoDB tables if they don't exist."""
try:
# Connections table
self.connections_table = self.dynamodb.Table(self.connections_table_name)
self.connections_table.load()
except ClientError:
self.connections_table = self._create_connections_table()
try:
# OAuth states table
self.oauth_states_table = self.dynamodb.Table(self.oauth_states_table_name)
self.oauth_states_table.load()
except ClientError:
self.oauth_states_table = self._create_oauth_states_table()
try:
# Audit logs table
self.audit_logs_table = self.dynamodb.Table(self.audit_logs_table_name)
self.audit_logs_table.load()
except ClientError:
self.audit_logs_table = self._create_audit_logs_table()
def _create_connections_table(self):
"""Create SharePoint connections table."""
table = self.dynamodb.create_table(
TableName=self.connections_table_name,
KeySchema=[
{'AttributeName': 'id', 'KeyType': 'HASH'} # Partition key
],
AttributeDefinitions=[
{'AttributeName': 'id', 'AttributeType': 'S'},
{'AttributeName': 'user_id', 'AttributeType': 'S'},
{'AttributeName': 'organization_id', 'AttributeType': 'S'}
],
GlobalSecondaryIndexes=[
{
'IndexName': 'user_id-index',
'KeySchema': [
{'AttributeName': 'user_id', 'KeyType': 'HASH'}
],
# Index capacity is omitted: ProvisionedThroughput cannot be combined
# with BillingMode='PAY_PER_REQUEST'
'Projection': {'ProjectionType': 'ALL'}
},
{
'IndexName': 'organization_id-index',
'KeySchema': [
{'AttributeName': 'organization_id', 'KeyType': 'HASH'}
],
# Index capacity is omitted: ProvisionedThroughput cannot be combined
# with BillingMode='PAY_PER_REQUEST'
'Projection': {'ProjectionType': 'ALL'}
}
],
BillingMode='PAY_PER_REQUEST' # On-demand pricing (or use ProvisionedThroughput)
)
table.wait_until_exists()
return table
def _create_oauth_states_table(self):
"""Create OAuth states table with TTL for automatic cleanup."""
table = self.dynamodb.create_table(
TableName=self.oauth_states_table_name,
KeySchema=[
{'AttributeName': 'state', 'KeyType': 'HASH'}
],
AttributeDefinitions=[
{'AttributeName': 'state', 'AttributeType': 'S'}
],
BillingMode='PAY_PER_REQUEST'
)
table.wait_until_exists()
# Enable TTL for automatic cleanup
try:
self.dynamodb.meta.client.update_time_to_live(
TableName=self.oauth_states_table_name,
TimeToLiveSpecification={
'Enabled': True,
'AttributeName': 'ttl'
}
)
except ClientError:
pass # TTL already enabled or not supported
return table
def _create_audit_logs_table(self):
"""Create audit logs table."""
table = self.dynamodb.create_table(
TableName=self.audit_logs_table_name,
KeySchema=[
{'AttributeName': 'connection_id', 'KeyType': 'HASH'},
{'AttributeName': 'timestamp', 'KeyType': 'RANGE'} # Sort key
],
AttributeDefinitions=[
{'AttributeName': 'connection_id', 'AttributeType': 'S'},
{'AttributeName': 'timestamp', 'AttributeType': 'S'},
{'AttributeName': 'user_id', 'AttributeType': 'S'}
],
GlobalSecondaryIndexes=[
{
'IndexName': 'user_id-timestamp-index',
'KeySchema': [
{'AttributeName': 'user_id', 'KeyType': 'HASH'},
{'AttributeName': 'timestamp', 'KeyType': 'RANGE'}
],
# Index capacity is omitted: ProvisionedThroughput cannot be combined
# with BillingMode='PAY_PER_REQUEST'
'Projection': {'ProjectionType': 'ALL'}
}
],
BillingMode='PAY_PER_REQUEST'
)
table.wait_until_exists()
return table
def initiate_connection(
self,
user_id: str,
organization_id: Optional[str] = None,
return_url: Optional[str] = None
) -> str:
"""
Initiate OAuth flow for a user to connect their SharePoint.
Args:
user_id: Your SaaS user ID
organization_id: Your SaaS organization ID (if applicable)
return_url: URL to redirect to after successful connection
Returns:
Authorization URL to redirect user to
"""
# Generate secure state token
state = secrets.token_urlsafe(32)
# Store state in DynamoDB for CSRF protection
expires_at = datetime.utcnow() + timedelta(minutes=10)
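# Unix timestamp for DynamoDB TTL. expires_at is a naive UTC value, so
# .timestamp() would wrongly interpret it as local time.
ttl = int((expires_at - datetime(1970, 1, 1)).total_seconds())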
self.oauth_states_table.put_item(
Item={
'state': state,
'user_id': user_id,
'organization_id': organization_id or '',
'return_url': return_url or '',
'expires_at': datetime_to_iso(expires_at),
'ttl': ttl, # DynamoDB will auto-delete after this time
'used': False
}
)
# Generate authorization URL
auth_url = self.msal_app.get_authorization_request_url(
scopes=self.scopes,
state=state,
redirect_uri=self.redirect_uri
)
return auth_url
def complete_connection(
self,
auth_code: str,
state: str,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None
) -> SharePointConnectionInfo:
"""
Complete OAuth flow and store connection.
Args:
auth_code: Authorization code from OAuth callback
state: State parameter from OAuth callback
ip_address: User's IP address (for audit log)
user_agent: User's user agent (for audit log)
Returns:
Connection information
Raises:
ValueError: If state is invalid or expired
Exception: If token acquisition fails
"""
# Validate state (CSRF protection)
try:
response = self.oauth_states_table.get_item(Key={'state': state})
oauth_state = response.get('Item')
except ClientError:
oauth_state = None
if not oauth_state:
raise ValueError("Invalid OAuth state")
# Check if state is valid
expires_at = iso_to_datetime(oauth_state.get('expires_at'))
# A missing or unparseable expiry is treated as expired
if oauth_state.get('used') or expires_at is None or datetime.utcnow() > expires_at:
raise ValueError("OAuth state expired or already used")
# Mark state as used
self.oauth_states_table.update_item(
Key={'state': state},
UpdateExpression='SET used = :val',
ExpressionAttributeValues={':val': True}
)
# Exchange code for tokens
token_response = self.msal_app.acquire_token_by_authorization_code(
code=auth_code,
scopes=self.scopes,
redirect_uri=self.redirect_uri
)
if "error" in token_response:
raise Exception(f"Token acquisition failed: {token_response.get('error_description', token_response['error'])}")
# Get user info from Microsoft
user_info = self._get_user_info(token_response["access_token"])
# Calculate token expiry
expires_in = token_response.get("expires_in", 3600)
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
# Encrypt tokens
encrypted_access_token = self.encryption.encrypt(token_response["access_token"])
encrypted_refresh_token = self.encryption.encrypt(token_response.get("refresh_token", ""))
# Create connection record
user_id = oauth_state['user_id']
organization_id = oauth_state.get('organization_id') or None
connection_id = self._generate_connection_id(user_id, user_info["id"])
now = datetime.utcnow()
connection_item = {
'id': connection_id,
'user_id': user_id,
'organization_id': organization_id or '',
'tenant_id': user_info.get('tenantId', self.tenant_id),
'microsoft_user_id': user_info["id"],
'connection_name': f"{user_info.get('displayName', 'SharePoint')} - {user_info.get('userPrincipalName', '')}",
'encrypted_access_token': encrypted_access_token,
'encrypted_refresh_token': encrypted_refresh_token,
'token_expires_at': datetime_to_iso(expires_at),
'scopes': json.dumps(self.scopes),
'is_active': True,
'created_at': datetime_to_iso(now),
'updated_at': datetime_to_iso(now),
'last_used_at': datetime_to_iso(now)
}
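# Store connection. organization_id is a GSI key and DynamoDB rejects empty
# strings for index key attributes, so omit it when no organization is set.
item_to_store = {k: v for k, v in connection_item.items() if k != 'organization_id' or v}
self.connections_table.put_item(Item=item_to_store)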
# Audit log
self._log_activity(
connection_id=connection_id,
user_id=user_id,
action="connection_created",
status="success",
ip_address=ip_address,
user_agent=user_agent
)
return SharePointConnectionInfo(
id=connection_id,
user_id=user_id,
organization_id=organization_id,
connection_name=connection_item['connection_name'],
tenant_id=connection_item['tenant_id'],
microsoft_user_id=user_info["id"],
is_active=True,
created_at=now,
last_used_at=now
)
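The state handling in `initiate_connection` and `complete_connection` follows a common pattern: issue a random token, give it a short lifetime, and accept it at most once. The in-memory sketch below illustrates that lifecycle under stated assumptions. `InMemoryStateStore` is a hypothetical stand-in for the OAuth states table and uses a plain dict, so it says nothing about DynamoDB consistency or concurrent callbacks.

```python
import secrets
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional


class InMemoryStateStore:
    """Hypothetical stand-in for the sharepoint_oauth_states table."""

    def __init__(self, lifetime: timedelta = timedelta(minutes=10)):
        self.lifetime = lifetime
        self._states: Dict[str, dict] = {}

    def issue(self, user_id: str, now: Optional[datetime] = None) -> str:
        """Create an unguessable state token bound to a user."""
        now = now or datetime.now(timezone.utc)
        state = secrets.token_urlsafe(32)
        self._states[state] = {
            "user_id": user_id,
            "expires_at": now + self.lifetime,
            "used": False,
        }
        return state

    def consume(self, state: str, now: Optional[datetime] = None) -> str:
        """Validate a state exactly once and return the user it was issued to."""
        now = now or datetime.now(timezone.utc)
        record = self._states.get(state)
        if record is None:
            raise ValueError("Invalid OAuth state")
        if record["used"] or now > record["expires_at"]:
            raise ValueError("OAuth state expired or already used")
        record["used"] = True  # a replayed callback fails from here on
        return record["user_id"]


store = InMemoryStateStore()
issued_at = datetime(2026, 1, 1, tzinfo=timezone.utc)
state = store.issue("user-123", now=issued_at)
first_use = store.consume(state, now=issued_at + timedelta(minutes=1))
```

Against a real table the check and the "mark as used" write are two separate requests. A conditional write (the `ConditionExpression` parameter of `update_item`) is one option for making that step atomic, though it is not what this commit does.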
def get_valid_token(
self,
connection_id: str,
user_id: str
) -> str:
"""
Get a valid access token, refreshing if necessary.
Args:
connection_id: SharePoint connection ID
user_id: Your SaaS user ID (for authorization check)
Returns:
Valid access token
Raises:
ValueError: If connection not found or doesn't belong to user
Exception: If token refresh fails
"""
# Get connection (with authorization check!)
try:
response = self.connections_table.get_item(Key={'id': connection_id})
connection = response.get('Item')
except ClientError:
connection = None
if not connection or connection.get('user_id') != user_id or not connection.get('is_active'):
raise ValueError("Connection not found or access denied")
# Check if token needs refresh
token_expires_at = iso_to_datetime(connection['token_expires_at'])
if datetime.utcnow() >= token_expires_at - timedelta(minutes=5):
# Refresh token
refresh_token = self.encryption.decrypt(connection['encrypted_refresh_token'])
token_response = self.msal_app.acquire_token_by_refresh_token(
refresh_token=refresh_token,
scopes=self.scopes
)
if "error" in token_response:
self._log_activity(
connection_id=connection_id,
user_id=user_id,
action="token_refresh",
status="failure",
details=json.dumps({"error": token_response.get("error")})
)
raise Exception(f"Token refresh failed: {token_response.get('error_description', token_response['error'])}")
# Update stored tokens
new_access_token = self.encryption.encrypt(token_response["access_token"])
new_refresh_token = token_response.get("refresh_token")
expires_in = token_response.get("expires_in", 3600)
new_expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
update_expression = "SET encrypted_access_token = :access, token_expires_at = :expires, updated_at = :updated, last_used_at = :used"
expression_values = {
':access': new_access_token,
':expires': datetime_to_iso(new_expires_at),
':updated': datetime_to_iso(datetime.utcnow()),
':used': datetime_to_iso(datetime.utcnow())
}
if new_refresh_token:
update_expression += ", encrypted_refresh_token = :refresh"
expression_values[':refresh'] = self.encryption.encrypt(new_refresh_token)
self.connections_table.update_item(
Key={'id': connection_id},
UpdateExpression=update_expression,
ExpressionAttributeValues=expression_values
)
self._log_activity(
connection_id=connection_id,
user_id=user_id,
action="token_refresh",
status="success"
)
connection['encrypted_access_token'] = new_access_token
else:
# Update last used timestamp
self.connections_table.update_item(
Key={'id': connection_id},
UpdateExpression='SET last_used_at = :val',
ExpressionAttributeValues={':val': datetime_to_iso(datetime.utcnow())}
)
# Decrypt and return access token
return self.encryption.decrypt(connection['encrypted_access_token'])
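`get_valid_token` refreshes a token once it is within five minutes of expiry rather than waiting for it to lapse. That decision can be isolated as a small pure function, sketched here. `REFRESH_SKEW` and `needs_refresh` are illustrative names, not identifiers from this commit.

```python
from datetime import datetime, timedelta, timezone

# Safety margin so a token never expires in the middle of a request
REFRESH_SKEW = timedelta(minutes=5)


def needs_refresh(
    expires_at: datetime, now: datetime, skew: timedelta = REFRESH_SKEW
) -> bool:
    """True when the token is expired or will expire within the skew window."""
    return now >= expires_at - skew


expires_at = datetime(2026, 1, 1, 13, 0, tzinfo=timezone.utc)
fresh = needs_refresh(expires_at, datetime(2026, 1, 1, 12, 30, tzinfo=timezone.utc))
stale = needs_refresh(expires_at, datetime(2026, 1, 1, 12, 56, tzinfo=timezone.utc))
```

Passing `now` in explicitly keeps the check deterministic in tests.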
def list_connections(
self,
user_id: str,
organization_id: Optional[str] = None
) -> List[SharePointConnectionInfo]:
"""
List all SharePoint connections for a user or organization.
Args:
user_id: Your SaaS user ID
organization_id: Your SaaS organization ID (optional)
Returns:
List of connections
"""
# Query by user_id using GSI
response = self.connections_table.query(
IndexName='user_id-index',
KeyConditionExpression=Key('user_id').eq(user_id),
FilterExpression=Attr('is_active').eq(True)
)
connections = response.get('Items', [])
# Filter by organization if specified
if organization_id:
connections = [c for c in connections if c.get('organization_id') == organization_id]
return [
SharePointConnectionInfo(
id=conn['id'],
user_id=conn['user_id'],
organization_id=conn.get('organization_id') or None,
connection_name=conn.get('connection_name'),
tenant_id=conn.get('tenant_id'),
microsoft_user_id=conn.get('microsoft_user_id'),
is_active=conn.get('is_active', False),
created_at=iso_to_datetime(conn.get('created_at')),
last_used_at=iso_to_datetime(conn.get('last_used_at'))
)
for conn in connections
]
def disconnect(
self,
connection_id: str,
user_id: str,
ip_address: Optional[str] = None
):
"""
Disconnect (deactivate) a SharePoint connection.
Args:
connection_id: SharePoint connection ID
user_id: Your SaaS user ID (for authorization)
ip_address: User's IP for audit log
"""
# Get connection to verify ownership
try:
response = self.connections_table.get_item(Key={'id': connection_id})
connection = response.get('Item')
except ClientError:
connection = None
if not connection or connection.get('user_id') != user_id:
raise ValueError("Connection not found or access denied")
# Deactivate connection
self.connections_table.update_item(
Key={'id': connection_id},
UpdateExpression='SET is_active = :val, updated_at = :updated',
ExpressionAttributeValues={
':val': False,
':updated': datetime_to_iso(datetime.utcnow())
}
)
self._log_activity(
connection_id=connection_id,
user_id=user_id,
action="connection_disconnected",
status="success",
ip_address=ip_address
)
def _get_user_info(self, access_token: str) -> Dict[str, Any]:
"""Get Microsoft user information."""
headers = {"Authorization": f"Bearer {access_token}"}
response = requests.get(
"https://graph.microsoft.com/v1.0/me?$select=id,displayName,userPrincipalName,mail",
headers=headers,
timeout=30  # avoid hanging indefinitely if Graph does not respond
)
response.raise_for_status()
return response.json()
def _generate_connection_id(self, user_id: str, microsoft_user_id: str) -> str:
"""Generate deterministic connection ID."""
combined = f"{user_id}:{microsoft_user_id}"
return hashlib.sha256(combined.encode()).hexdigest()[:32]
def _log_activity(
self,
connection_id: str,
user_id: str,
action: str,
status: str,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
details: Optional[str] = None
):
"""Create audit log entry."""
timestamp = datetime.utcnow()
self.audit_logs_table.put_item(
Item={
'connection_id': connection_id,
'timestamp': datetime_to_iso(timestamp),
'user_id': user_id,
'action': action,
'status': status,
'ip_address': ip_address or '',
'user_agent': user_agent or '',
'details': details or ''
}
)
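Connection IDs are derived deterministically from the SaaS user ID and the Microsoft user ID, as `_generate_connection_id` shows. The standalone sketch below reproduces that derivation with made-up IDs. `connection_id_for` is a hypothetical free-function copy used only for illustration.

```python
import hashlib


def connection_id_for(user_id: str, microsoft_user_id: str) -> str:
    """First 32 hex characters (128 bits) of SHA-256 over 'user:microsoft_user'."""
    combined = f"{user_id}:{microsoft_user_id}"
    return hashlib.sha256(combined.encode()).hexdigest()[:32]


# Made-up identifiers for illustration
first = connection_id_for("user-123", "ms-abc")
second = connection_id_for("user-123", "ms-abc")
other = connection_id_for("user-123", "ms-xyz")
```

Because the ID depends only on the two user IDs, reconnecting the same pair maps to the same item, so the `put_item` in `complete_connection` replaces the earlier record (including its `created_at`) instead of adding a second one.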
# ============================================================================
# SHAREPOINT API CLIENT
# ============================================================================
class SecureSharePointClient:
"""
SharePoint API client with automatic token management.
Use this to make SharePoint API calls on behalf of connected users.
"""
def __init__(self, connector: DynamoDBSharePointConnector, connection_id: str, user_id: str):
"""
Initialize client.
Args:
connector: DynamoDBSharePointConnector instance
connection_id: SharePoint connection ID
user_id: Your SaaS user ID
"""
self.connector = connector
self.connection_id = connection_id
self.user_id = user_id
def _get_headers(self) -> Dict[str, str]:
"""Get headers with valid access token."""
access_token = self.connector.get_valid_token(self.connection_id, self.user_id)
return {
"Authorization": f"Bearer {access_token}",
"Accept": "application/json"
}
def list_sites(self) -> List[Dict[str, Any]]:
"""List SharePoint sites the user has access to."""
response = requests.get(
"https://graph.microsoft.com/v1.0/sites?search=*",
headers=self._get_headers()
)
response.raise_for_status()
return response.json().get("value", [])
def read_file(self, site_id: str, file_path: str, as_text: bool = True) -> Any:
"""Read file content."""
encoded_path = requests.utils.quote(file_path)
url = f"https://graph.microsoft.com/v1.0/sites/{site_id}/drive/root:/{encoded_path}:/content"
response = requests.get(url, headers=self._get_headers())
response.raise_for_status()
return response.text if as_text else response.content
def list_files(self, site_id: str, folder_path: str = "") -> List[Dict[str, Any]]:
"""List files in a folder."""
if folder_path:
encoded_path = requests.utils.quote(folder_path)
url = f"https://graph.microsoft.com/v1.0/sites/{site_id}/drive/root:/{encoded_path}:/children"
else:
url = f"https://graph.microsoft.com/v1.0/sites/{site_id}/drive/root/children"
response = requests.get(url, headers=self._get_headers())
response.raise_for_status()
return response.json().get("value", [])
def search_files(self, site_id: str, query: str) -> List[Dict[str, Any]]:
"""Search for files."""
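# Double any single quotes so the quoted search literal stays well-formed
safe_query = query.replace("'", "''")
url = f"https://graph.microsoft.com/v1.0/sites/{site_id}/drive/root/search(q='{safe_query}')"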
response = requests.get(url, headers=self._get_headers())
response.raise_for_status()
return response.json().get("value", [])
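The list and search methods above return only the `value` array of the first response. Microsoft Graph may split large result sets across pages and link them with an `@odata.nextLink` URL. The sketch below shows one way to follow those links. `iter_graph_items`, the injected `fetch_json` callable, and the `example.invalid` URLs are assumptions for illustration, not part of this client.

```python
from typing import Any, Callable, Dict, Iterator, Optional


def iter_graph_items(
    fetch_json: Callable[[str], Dict[str, Any]],
    first_url: str,
    max_pages: int = 50,
) -> Iterator[Dict[str, Any]]:
    """Yield items from every page, following @odata.nextLink until it is absent.

    fetch_json is any callable that returns the decoded JSON body for a URL.
    max_pages is a hypothetical safety cap against runaway paging.
    """
    url: Optional[str] = first_url
    pages = 0
    while url and pages < max_pages:
        payload = fetch_json(url)
        yield from payload.get("value", [])
        url = payload.get("@odata.nextLink")
        pages += 1


# Fake two-page response standing in for real Graph calls (illustrative data)
fake_pages = {
    "https://example.invalid/page1": {
        "value": [{"name": "a.docx"}, {"name": "b.pdf"}],
        "@odata.nextLink": "https://example.invalid/page2",
    },
    "https://example.invalid/page2": {"value": [{"name": "c.xlsx"}]},
}
names = [
    item["name"]
    for item in iter_graph_items(fake_pages.__getitem__, "https://example.invalid/page1")
]
```

Inside the client, `fetch_json` could be a small function that calls `requests.get(url, headers=self._get_headers())`, raises on HTTP errors, and returns `response.json()`.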
# ============================================================================
# FLASK INTEGRATION EXAMPLE
# ============================================================================
def create_app(connector: DynamoDBSharePointConnector) -> Flask:
"""Create Flask app with DynamoDB SharePoint connector integration."""
app = Flask(__name__)
app.secret_key = os.getenv("FLASK_SECRET_KEY", secrets.token_urlsafe(32))
def require_auth(f):
"""Decorator to require authentication."""
@wraps(f)
def decorated_function(*args, **kwargs):
if "user_id" not in session:
return jsonify({"error": "Authentication required"}), 401
return f(*args, **kwargs)
return decorated_function
@app.route("/sharepoint/connect")
@require_auth
def connect_sharepoint():
"""Initiate SharePoint connection."""
user_id = session["user_id"]
organization_id = session.get("organization_id")
auth_url = connector.initiate_connection(
user_id=user_id,
organization_id=organization_id,
return_url=request.args.get("return_url", "/dashboard")
)
return redirect(auth_url)
@app.route("/sharepoint/callback")
def sharepoint_callback():
"""OAuth callback endpoint."""
if "error" in request.args:
return jsonify({
"error": request.args.get("error_description", request.args["error"])
}), 400
auth_code = request.args.get("code")
state = request.args.get("state")
if not auth_code or not state:
return jsonify({"error": "Invalid callback"}), 400
try:
connection_info = connector.complete_connection(
auth_code=auth_code,
state=state,
ip_address=request.remote_addr,
user_agent=request.headers.get("User-Agent")
)
session["sharepoint_connection_id"] = connection_info.id
return jsonify({
"success": True,
"connection": {
"id": connection_info.id,
"name": connection_info.connection_name
}
})
except Exception as e:
return jsonify({"error": str(e)}), 500
@app.route("/api/sharepoint/connections")
@require_auth
def list_connections():
"""List user's SharePoint connections."""
user_id = session["user_id"]
connections = connector.list_connections(user_id)
return jsonify({
"connections": [
{
"id": conn.id,
"name": conn.connection_name,
"created_at": conn.created_at.isoformat() if conn.created_at else None,
"last_used_at": conn.last_used_at.isoformat() if conn.last_used_at else None
}
for conn in connections
]
})
@app.route("/api/sharepoint/connections/<connection_id>/disconnect", methods=["POST"])
@require_auth
def disconnect_sharepoint(connection_id):
"""Disconnect SharePoint."""
user_id = session["user_id"]
try:
connector.disconnect(
connection_id=connection_id,
user_id=user_id,
ip_address=request.remote_addr
)
return jsonify({"success": True})
except Exception as e:
return jsonify({"error": str(e)}), 400
@app.route("/api/sharepoint/<connection_id>/sites")
@require_auth
def get_sites(connection_id):
"""Get SharePoint sites."""
user_id = session["user_id"]
try:
client = SecureSharePointClient(connector, connection_id, user_id)
sites = client.list_sites()
return jsonify({"sites": decimal_to_float(sites)})
except Exception as e:
return jsonify({"error": str(e)}), 500
return app
if __name__ == "__main__":
# Initialize connector
connector = DynamoDBSharePointConnector(
client_id=os.getenv("SHAREPOINT_CLIENT_ID"),
client_secret=os.getenv("SHAREPOINT_CLIENT_SECRET"),
redirect_uri=os.getenv("REDIRECT_URI"),
encryption_key=os.getenv("ENCRYPTION_KEY"),
aws_region=os.getenv("AWS_REGION", "us-east-1"),
table_prefix=os.getenv("TABLE_PREFIX", ""),
tenant_id=os.getenv("SHAREPOINT_TENANT_ID", "common")
)
# Create and run Flask app
app = create_app(connector)
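# Honor FLASK_DEBUG and PORT from the environment (see .env.example)
app.run(
debug=os.getenv("FLASK_DEBUG", "false").lower() == "true",
port=int(os.getenv("PORT", "5000"))
)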

23
setup_agent.py Normal file
View File

@@ -0,0 +1,23 @@
import os
from dotenv import load_dotenv
from toothfairy_client import ToothFairyClient
# 1. Manually load the .env file so the script can see your keys
load_dotenv()
# 2. Initialize the client
client = ToothFairyClient()
# 3. Create the Tool pulling the URL from .env
ngrok_url = os.getenv("NGROK_URL")
if not ngrok_url:
# SystemExit prints the message to stderr and exits with status 1
raise SystemExit("❌ Error: NGROK_URL not found in .env file. Please add it and try again.")
tool_id = client.create_search_tool(ngrok_url)
print(f"✅ Tool created with ID: {tool_id}")
# 4. Create the Agent and attach the Tool
agent_id = client.create_agent(tool_id)
print(f"✅ Agent created with ID: {agent_id}")

30
sonar-project.properties Normal file
View File

@@ -0,0 +1,30 @@
# SonarQube Project Configuration for SharePoint Connector
# Project identification
sonar.projectKey=sharepoint-connector
sonar.projectName=SharePoint Connector Plugin
sonar.projectVersion=1.0
# Source code
sonar.sources=.
sonar.sourceEncoding=UTF-8
# Tests
sonar.tests=.
sonar.test.inclusions=test_*.py
# Exclusions
sonar.exclusions=venv/**,**/__pycache__/**,*.pyc,.venv/**,htmlcov/**,templates/**,static/**,**/migrations/**
# Python specific
sonar.python.version=3.11
sonar.python.coverage.reportPaths=coverage.xml
# Coverage exclusions (files that don't need coverage)
sonar.coverage.exclusions=test_*.py,**/__init__.py,**/migrations/**,clear_data.py
# Duplications
sonar.cpd.exclusions=test_*.py
# Analysis parameters
sonar.scm.provider=git

1310
static/style.css Normal file

File diff suppressed because it is too large

@@ -0,0 +1,155 @@
"""
Secure Credentials Storage
Encrypts and stores SharePoint OAuth credentials to disk for persistence across restarts.
Uses Fernet symmetric encryption to protect sensitive data.
"""
import json
import os
from typing import Dict, Optional
from cryptography.fernet import Fernet
from pathlib import Path
import logging

logger = logging.getLogger(__name__)


class CredentialsStorage:
    """
    Secure storage for SharePoint OAuth credentials.
    Stores credentials encrypted on disk at ~/.sharepoint_credentials/credentials.enc
    All users' credentials are kept in that single encrypted file, keyed by user ID.
    """

    def __init__(self, storage_dir: Optional[str] = None):
        """
        Initialize credentials storage.

        Args:
            storage_dir: Directory to store credentials (default: ~/.sharepoint_credentials)
        """
        if storage_dir is None:
            storage_dir = os.path.expanduser("~/.sharepoint_credentials")
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)
        self.credentials_file = self.storage_dir / "credentials.enc"
        self.key_file = self.storage_dir / "key.key"
        # Initialize encryption key
        self.cipher = self._get_or_create_cipher()
        # Load existing credentials
        self.credentials: Dict[str, Dict] = self._load_credentials()

    def _get_or_create_cipher(self) -> Fernet:
        """Load the encryption key from disk, or create and store a new one."""
        if self.key_file.exists():
            with open(self.key_file, 'rb') as f:
                key = f.read()
        else:
            key = Fernet.generate_key()
            with open(self.key_file, 'wb') as f:
                f.write(key)
            # Set restrictive permissions on key file
            os.chmod(self.key_file, 0o600)
        return Fernet(key)

    def _load_credentials(self) -> Dict[str, Dict]:
        """Load credentials from disk; returns an empty dict if none can be read."""
        if not self.credentials_file.exists():
            return {}
        try:
            with open(self.credentials_file, 'rb') as f:
                encrypted_data = f.read()
            if not encrypted_data:
                return {}
            decrypted_data = self.cipher.decrypt(encrypted_data)
            return json.loads(decrypted_data.decode('utf-8'))
        except Exception as e:
            logger.error(f"Failed to load credentials: {e}")
            return {}

    def _save_credentials(self):
        """Encrypt all credentials and write them to disk."""
        try:
            data = json.dumps(self.credentials).encode('utf-8')
            encrypted_data = self.cipher.encrypt(data)
            with open(self.credentials_file, 'wb') as f:
                f.write(encrypted_data)
            # Set restrictive permissions on credentials file
            os.chmod(self.credentials_file, 0o600)
            logger.info("Credentials saved to disk")
        except Exception as e:
            logger.error(f"Failed to save credentials: {e}")
            raise

    def save_config(self, user_id: str, config: Dict) -> None:
        """
        Save user configuration.

        Args:
            user_id: User identifier
            config: Configuration dictionary containing:
                - tenant_id: Azure tenant ID
                - client_id: Azure client ID
                - client_secret: Azure client secret
        """
        self.credentials[user_id] = config
        self._save_credentials()
        logger.info(f"Saved credentials for user {user_id}")

    def get_config(self, user_id: str) -> Optional[Dict]:
        """
        Get user configuration.

        Args:
            user_id: User identifier

        Returns:
            Configuration dictionary or None if not found
        """
        return self.credentials.get(user_id)

    def delete_config(self, user_id: str) -> bool:
        """
        Delete user configuration.

        Args:
            user_id: User identifier

        Returns:
            True if deleted, False if not found
        """
        if user_id in self.credentials:
            del self.credentials[user_id]
            self._save_credentials()
            logger.info(f"Deleted credentials for user {user_id}")
            return True
        return False

    def list_users(self) -> list:
        """Get list of users with stored credentials."""
        return list(self.credentials.keys())


# Global credentials storage instance
_credentials_storage = None


def get_credentials_storage() -> CredentialsStorage:
    """Get the global credentials storage instance."""
    global _credentials_storage
    if _credentials_storage is None:
        _credentials_storage = CredentialsStorage()
    return _credentials_storage
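Note on file permissions: the storage class writes the key and credentials files first and calls `os.chmod(..., 0o600)` afterwards, so there is a short window in which the file has default permissions. One possible way to avoid that window is sketched below using only the standard library. The helper name `write_private_file` is hypothetical and is not part of this commit; it relies on `tempfile.mkstemp`, which creates the file readable and writable only by the creating user.

```python
import os
import tempfile


def write_private_file(path: str, data: bytes) -> None:
    """Write data to path so the file is never readable by other users."""
    directory = os.path.dirname(os.path.abspath(path))
    # mkstemp creates the temporary file with owner-only permissions from the start
    fd, tmp_path = tempfile.mkstemp(dir=directory)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        # Rename over the target; atomic when both paths are on the same filesystem
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the temporary file if anything failed before the rename
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
        raise
```

Writing to a temporary file in the same directory and renaming it also means a crash mid-write leaves the previous credentials file intact rather than truncated.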

57
task-definition.json Normal file

@@ -0,0 +1,57 @@
{
  "family": "sharepoint-connector",
  "networkMode": "awsvpc",
  "requiresCompatibilities": ["FARGATE"],
  "cpu": "512",
  "memory": "1024",
  "executionRoleArn": "arn:aws:iam::YOUR_ACCOUNT_ID:role/ecsTaskExecutionRole",
  "taskRoleArn": "arn:aws:iam::YOUR_ACCOUNT_ID:role/sharepoint-connector-task-role",
  "containerDefinitions": [
    {
      "name": "sharepoint-connector",
      "image": "YOUR_ACCOUNT_ID.dkr.ecr.ap-southeast-2.amazonaws.com/sharepoint-connector:latest",
      "essential": true,
      "portMappings": [
        {
          "containerPort": 8000,
          "protocol": "tcp"
        }
      ],
      "environment": [
        {
          "name": "AWS_REGION",
          "value": "ap-southeast-2"
        },
        {
          "name": "TABLE_PREFIX",
          "value": "prod_"
        },
        {
          "name": "PORT",
          "value": "8000"
        }
      ],
      "secrets": [
        {
          "name": "FLASK_SECRET_KEY",
          "valueFrom": "arn:aws:secretsmanager:ap-southeast-2:YOUR_ACCOUNT_ID:secret:sharepoint-connector/flask-secret"
        }
      ],
      "healthCheck": {
        "command": ["CMD-SHELL", "python -c \"import requests; requests.get('http://localhost:8000/health', timeout=3).raise_for_status()\""],
        "interval": 30,
        "timeout": 5,
        "retries": 3,
        "startPeriod": 60
      },
      "logConfiguration": {
        "logDriver": "awslogs",
        "options": {
          "awslogs-group": "/ecs/sharepoint-connector",
          "awslogs-region": "ap-southeast-2",
          "awslogs-stream-prefix": "ecs"
        }
      }
    }
  ]
}
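Note on the container health check: a probe only marks the task unhealthy when its command exits non-zero, so it has to fail on error responses and not only on refused connections. A minimal standard-library sketch of such a probe is shown below. The function name `probe_health` is hypothetical and is not part of this commit; the `/health` path and port 8000 are taken from the task definition.

```python
import urllib.error
import urllib.request


def probe_health(url: str, timeout: float = 3.0) -> bool:
    """Return True only if url answers with HTTP 200 within timeout seconds."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            return response.status == 200
    except (urllib.error.URLError, OSError, ValueError):
        # HTTPError (4xx/5xx), refused connections, timeouts and malformed
        # URLs all count as unhealthy
        return False
```

A container command could then call something like `sys.exit(0 if probe_health("http://localhost:8000/health") else 1)`, which would also remove the health check's dependency on the `requests` package.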

25
templates/error.html Normal file

@@ -0,0 +1,25 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Error - SharePoint Connector</title>
<link rel="stylesheet" href="/static/style.css">
</head>
<body>
<div class="container">
<header>
<h1>❌ Error</h1>
</header>
<div class="status-card">
<div class="error">
<h3>Something went wrong</h3>
<p>{{ error }}</p>
<br>
<a href="/" class="btn btn-primary">Go Home</a>
</div>
</div>
</div>
</body>
</html>

1417
templates/index.html Normal file

File diff suppressed because it is too large

286
test_app.py Normal file

@@ -0,0 +1,286 @@
"""
Unit tests for app.py Flask routes
"""
import pytest
import json
from unittest.mock import Mock, patch, MagicMock
from app import app


@pytest.fixture
def client():
    """Create a test client for the Flask app."""
    app.config['TESTING'] = True
    app.config['SECRET_KEY'] = 'test-secret-key'
    with app.test_client() as client:
        yield client


@pytest.fixture
def mock_session():
    """Create a mock session."""
    return {"user_id": "test_user_123"}


class TestHealthCheck:
    """Test health check endpoint."""

    def test_health_endpoint(self, client):
        """Test /health endpoint returns healthy status."""
        response = client.get('/health')
        assert response.status_code == 200
        data = json.loads(response.data)
        assert data['status'] == 'healthy'


class TestIndexRoute:
    """Test index route."""

    def test_index_returns_template(self, client):
        """Test index route returns 200."""
        with patch('app.render_template') as mock_render:
            mock_render.return_value = '<html>Test</html>'
            response = client.get('/')
            assert response.status_code == 200


class TestConfigurationAPI:
    """Test configuration API endpoints."""

    def test_check_config_not_configured(self, client):
        """Test check_config returns false when not configured."""
        response = client.get('/api/config/check')
        assert response.status_code == 200
        data = json.loads(response.data)
        assert 'configured' in data

    def test_save_config_missing_credentials(self, client):
        """Test save_config fails without credentials."""
        response = client.post(
            '/api/config/save',
            data=json.dumps({}),
            content_type='application/json'
        )
        assert response.status_code == 400
        data = json.loads(response.data)
        assert 'error' in data

    @patch('app.credentials_storage')
    def test_save_config_success(self, mock_storage, client):
        """Test save_config succeeds with valid credentials."""
        mock_storage.get_config.return_value = None
        mock_storage.save_config.return_value = None
        response = client.post(
            '/api/config/save',
            data=json.dumps({
                'client_id': 'test-client-id',
                'client_secret': 'test-client-secret',
                'tenant_id': 'common'
            }),
            content_type='application/json'
        )
        assert response.status_code == 200
        data = json.loads(response.data)
        assert data['success'] is True

    @patch('app.credentials_storage')
    def test_reset_config(self, mock_storage, client):
        """Test reset_config clears configuration."""
        mock_storage.delete_config.return_value = None
        response = client.post('/api/config/reset')
        assert response.status_code == 200
        data = json.loads(response.data)
        assert data['success'] is True


class TestSharePointAPI:
    """Test SharePoint API endpoints."""

    def test_list_connections_not_configured(self, client):
        """Test list_connections fails without configuration."""
        response = client.get('/api/sharepoint/connections')
        assert response.status_code == 400
        data = json.loads(response.data)
        assert 'error' in data

    def test_get_sites_missing_connection_id(self, client):
        """Test get_sites requires valid connection_id."""
        response = client.get('/api/sharepoint/invalid_conn/sites')
        assert response.status_code == 400

    @patch('app.get_or_create_connector')
    def test_get_files_missing_site_id(self, mock_connector, client):
        """Test get_files requires site_id parameter."""
        mock_connector.return_value = Mock()
        response = client.get('/api/sharepoint/test_conn/files')
        assert response.status_code == 400
        data = json.loads(response.data)
        assert 'site_id is required' in data['error']

    @patch('app.get_or_create_connector')
    def test_read_file_missing_parameters(self, mock_connector, client):
        """Test read_file requires site_id and file_path."""
        mock_connector.return_value = Mock()
        response = client.get('/api/sharepoint/test_conn/read')
        assert response.status_code == 400
        data = json.loads(response.data)
        assert 'required' in data['error']


class TestChatAPI:
    """Test chat API endpoints."""

    def test_chat_send_missing_parameters(self, client):
        """Test chat_send requires all parameters."""
        response = client.post(
            '/api/chat/send',
            data=json.dumps({}),
            content_type='application/json'
        )
        assert response.status_code == 400
        data = json.loads(response.data)
        assert 'required' in data['error']

    def test_chat_send_no_document_loaded(self, client):
        """Test chat_send fails without loaded document."""
        response = client.post(
            '/api/chat/send',
            data=json.dumps({
                'site_id': 'test_site',
                'file_path': 'test.txt',
                'message': 'Hello'
            }),
            content_type='application/json'
        )
        assert response.status_code == 400
        data = json.loads(response.data)
        assert 'No document loaded' in data['error']

    def test_chat_history_missing_parameters(self, client):
        """Test chat_history requires parameters."""
        response = client.get('/api/chat/history')
        assert response.status_code == 400
        data = json.loads(response.data)
        assert 'required' in data['error']

    def test_chat_clear_missing_parameters(self, client):
        """Test chat_clear requires parameters."""
        response = client.post(
            '/api/chat/clear',
            data=json.dumps({}),
            content_type='application/json'
        )
        assert response.status_code == 400


class TestLLMStatus:
    """Test LLM status endpoint."""

    @patch('app.llm_client')
    def test_llm_status_available(self, mock_llm, client):
        """Test LLM status returns available when service is up."""
        mock_llm.is_available.return_value = True
        response = client.get('/api/llm/status')
        assert response.status_code == 200
        data = json.loads(response.data)
        assert data['available'] is True
        assert 'provider' in data

    @patch('app.llm_client')
    def test_llm_status_unavailable(self, mock_llm, client):
        """Test LLM status returns unavailable when service is down."""
        mock_llm.is_available.side_effect = Exception("Connection failed")
        response = client.get('/api/llm/status')
        assert response.status_code == 200
        data = json.loads(response.data)
        assert data['available'] is False
        assert 'error' in data


class TestVectorStoreAPI:
    """Test vector store / multi-document chat API."""

    @patch('app.vector_store', None)
    def test_add_document_vector_store_unavailable(self, client):
        """Test add_document fails when vector store is unavailable."""
        response = client.post(
            '/api/documents/add',
            data=json.dumps({}),
            content_type='application/json'
        )
        assert response.status_code == 503
        data = json.loads(response.data)
        assert 'Vector store not available' in data['error']

    @patch('app.vector_store', None)
    def test_list_tags_vector_store_unavailable(self, client):
        """Test list_tags fails when vector store is unavailable."""
        response = client.get('/api/documents/tags')
        assert response.status_code == 503

    @patch('app.vector_store', None)
    def test_chat_multi_vector_store_unavailable(self, client):
        """Test chat_multi fails when vector store is unavailable."""
        response = client.post(
            '/api/chat/multi',
            data=json.dumps({'message': 'test'}),
            content_type='application/json'
        )
        assert response.status_code == 503


class TestIndexingAPI:
    """Test background indexing API."""

    @patch('app.vector_store', None)
    @patch('app.get_or_create_connector')
    def test_start_indexing_vector_store_unavailable(self, mock_connector, client):
        """Test start_indexing fails without vector store."""
        mock_connector.return_value = Mock()
        response = client.post(
            '/api/indexing/start',
            data=json.dumps({}),
            content_type='application/json'
        )
        assert response.status_code == 503

    def test_get_indexing_status_not_found(self, client):
        """Test get_indexing_status returns 404 for unknown job."""
        with patch('app.get_indexer') as mock_indexer:
            mock_indexer.return_value.get_job_status.return_value = None
            response = client.get('/api/indexing/status/unknown_job')
            assert response.status_code == 404


class TestErrorHandlers:
    """Test error handlers."""

    def test_404_handler(self, client):
        """Test 404 error handler."""
        response = client.get('/nonexistent-route')
        assert response.status_code == 404
        data = json.loads(response.data)
        assert 'error' in data
        assert data['error'] == 'Not found'


class TestDecorators:
    """Test authentication decorators."""

    def test_require_auth_creates_user_id(self, client):
        """Test require_auth decorator creates user_id in session."""
        with client.session_transaction() as sess:
            assert 'user_id' not in sess
        response = client.get('/')
        with client.session_transaction() as sess:
            assert 'user_id' in sess

145
test_document_parser.py Normal file

@@ -0,0 +1,145 @@
"""
Unit tests for document_parser.py
"""
import pytest
from document_parser import DocumentParser, get_file_info


class TestDocumentParser:
    """Test DocumentParser class."""

    def setup_method(self):
        """Set up test fixtures."""
        self.parser = DocumentParser()

    def test_can_parse_supported_extensions(self):
        """Test can_parse returns True for supported file types."""
        supported_files = [
            'document.txt', 'readme.md', 'data.csv', 'config.json',
            'report.pdf', 'document.docx', 'spreadsheet.xlsx', 'slides.pptx',
            'script.py', 'code.java', 'style.css', 'index.html'
        ]
        for filename in supported_files:
            assert self.parser.can_parse(filename), f"Should parse {filename}"

    def test_can_parse_unsupported_extensions(self):
        """Test can_parse returns False for unsupported file types."""
        unsupported_files = [
            'image.png', 'video.mp4', 'audio.mp3', 'archive.zip',
            'binary.exe', 'document.doc'
        ]
        for filename in unsupported_files:
            assert not self.parser.can_parse(filename), f"Should not parse {filename}"

    def test_get_extension(self):
        """Test _get_extension method."""
        assert self.parser._get_extension('file.txt') == '.txt'
        assert self.parser._get_extension('FILE.TXT') == '.txt'
        assert self.parser._get_extension('archive.tar.gz') == '.gz'
        assert self.parser._get_extension('noextension') == ''

    def test_parse_text_utf8(self):
        """Test parsing UTF-8 text files."""
        content = "Hello World\nThis is a test".encode('utf-8')
        result = self.parser.parse(content, 'test.txt')
        assert result == "Hello World\nThis is a test"

    def test_parse_text_multiple_encodings(self):
        """Test parsing text with different encodings."""
        content = "Test content"
        # UTF-8
        result = self.parser._parse_text(content.encode('utf-8'))
        assert result == "Test content"
        # Latin-1
        result = self.parser._parse_text(content.encode('latin-1'))
        assert result == "Test content"

    def test_parse_unsupported_file_raises_error(self):
        """Test parsing unsupported file type raises ValueError."""
        with pytest.raises(ValueError, match="Unsupported file type"):
            self.parser.parse(b"content", "file.exe")

    def test_parse_json(self):
        """Test parsing JSON files."""
        content = '{"key": "value", "number": 123}'.encode('utf-8')
        result = self.parser.parse(content, 'data.json')
        assert '"key": "value"' in result
        assert '"number": 123' in result

    def test_parse_csv(self):
        """Test parsing CSV files."""
        content = "name,age,city\nAlice,30,NYC\nBob,25,LA".encode('utf-8')
        result = self.parser.parse(content, 'data.csv')
        assert "Alice" in result
        assert "30" in result
        assert "NYC" in result


class TestGetFileInfo:
    """Test get_file_info function."""

    def test_document_category(self):
        """Test document file type categorization."""
        info = get_file_info('report.pdf', 1024)
        assert info['category'] == 'document'
        assert info['extension'] == '.pdf'

    def test_spreadsheet_category(self):
        """Test spreadsheet file type categorization."""
        info = get_file_info('data.xlsx', 2048)
        assert info['category'] == 'spreadsheet'
        assert info['extension'] == '.xlsx'

    def test_presentation_category(self):
        """Test presentation file type categorization."""
        info = get_file_info('slides.pptx', 4096)
        assert info['category'] == 'presentation'

    def test_code_category(self):
        """Test code file type categorization."""
        info = get_file_info('script.py', 512)
        assert info['category'] == 'code'

    def test_image_category(self):
        """Test image file type categorization."""
        info = get_file_info('photo.jpg', 8192)
        assert info['category'] == 'image'

    def test_file_size_bytes(self):
        """Test file size formatting in bytes."""
        info = get_file_info('small.txt', 512)
        assert info['size_formatted'] == '512 B'
        assert info['size_bytes'] == 512

    def test_file_size_kilobytes(self):
        """Test file size formatting in KB."""
        info = get_file_info('medium.txt', 2048)
        assert 'KB' in info['size_formatted']
        assert info['size_bytes'] == 2048

    def test_file_size_megabytes(self):
        """Test file size formatting in MB."""
        info = get_file_info('large.pdf', 5 * 1024 * 1024)
        assert 'MB' in info['size_formatted']
        assert '5.0' in info['size_formatted']

    def test_file_size_gigabytes(self):
        """Test file size formatting in GB."""
        info = get_file_info('huge.zip', 2 * 1024 * 1024 * 1024)
        assert 'GB' in info['size_formatted']
        assert '2.0' in info['size_formatted']

    def test_unknown_extension(self):
        """Test unknown file extension."""
        info = get_file_info('file.xyz', 1024)
        assert info['category'] == 'file'
        assert info['extension'] == '.xyz'

    def test_no_extension(self):
        """Test file with no extension."""
        info = get_file_info('README', 1024)
        assert info['extension'] == ''
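Note on the size-formatting tests: document_parser.py is not shown in this part of the diff, so the following is only a sketch of a formatter that would satisfy the expectations above ('512 B', a KB value, '5.0' MB, '2.0' GB). The function name `format_size` and the choice of binary units with one decimal place are assumptions, not the actual implementation behind `get_file_info`.

```python
def format_size(size_bytes: int) -> str:
    """Format a byte count using binary units (1 KB = 1024 B)."""
    if size_bytes < 1024:
        # Small values are shown as whole bytes
        return f"{size_bytes} B"
    size = float(size_bytes)
    for unit in ("KB", "MB"):
        size /= 1024
        if size < 1024:
            return f"{size:.1f} {unit}"
    # Anything of 1024 MB or more is reported in GB
    return f"{size / 1024:.1f} GB"
```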

270
test_llm_client.py Normal file

@@ -0,0 +1,270 @@
"""
Unit tests for llm_client.py
"""
import pytest
from unittest.mock import Mock, patch, MagicMock
from llm_client import (
    OllamaClient,
    OpenAIClient,
    AnthropicClient,
    create_llm_client
)


class TestOllamaClient:
    """Test OllamaClient class."""

    def setup_method(self):
        """Set up test fixtures."""
        self.client = OllamaClient(
            base_url="http://localhost:11434",
            model="llama3.2",
            timeout=30
        )

    def test_initialization(self):
        """Test client initialization."""
        assert self.client.base_url == "http://localhost:11434"
        assert self.client.model == "llama3.2"
        assert self.client.timeout == 30

    def test_initialization_strips_trailing_slash(self):
        """Test base_url trailing slash is removed."""
        client = OllamaClient(base_url="http://localhost:11434/")
        assert client.base_url == "http://localhost:11434"

    @patch('llm_client.requests.post')
    def test_chat_without_context(self, mock_post):
        """Test chat method without context."""
        mock_response = Mock()
        mock_response.json.return_value = {
            "message": {"content": "Hello! How can I help?"}
        }
        mock_post.return_value = mock_response
        messages = [{"role": "user", "content": "Hi"}]
        response = self.client.chat(messages)
        assert response == "Hello! How can I help?"
        mock_post.assert_called_once()

    @patch('llm_client.requests.post')
    def test_chat_with_context(self, mock_post):
        """Test chat method with context."""
        mock_response = Mock()
        mock_response.json.return_value = {
            "message": {"content": "Based on the document..."}
        }
        mock_post.return_value = mock_response
        messages = [{"role": "user", "content": "What does it say?"}]
        context = "This is a test document."
        response = self.client.chat(messages, context=context)
        assert response == "Based on the document..."
        # Verify system message with context was added
        call_args = mock_post.call_args
        sent_messages = call_args[1]['json']['messages']
        assert sent_messages[0]['role'] == 'system'
        assert 'test document' in sent_messages[0]['content']

    @patch('llm_client.requests.post')
    def test_chat_handles_api_error(self, mock_post):
        """Test chat handles API errors."""
        import requests
        mock_post.side_effect = requests.exceptions.RequestException("Connection failed")
        messages = [{"role": "user", "content": "Hi"}]
        with pytest.raises(Exception, match="Ollama API error"):
            self.client.chat(messages)

    @patch('llm_client.requests.get')
    def test_is_available_success(self, mock_get):
        """Test is_available returns True when service is up."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_get.return_value = mock_response
        assert self.client.is_available() is True

    @patch('llm_client.requests.get')
    def test_is_available_failure(self, mock_get):
        """Test is_available returns False when service is down."""
        mock_get.side_effect = Exception("Connection refused")
        assert self.client.is_available() is False


class TestOpenAIClient:
    """Test OpenAIClient class."""

    def setup_method(self):
        """Set up test fixtures."""
        self.client = OpenAIClient(
            api_key="sk-test123",
            model="gpt-4",
            timeout=30
        )

    def test_initialization(self):
        """Test client initialization."""
        assert self.client.api_key == "sk-test123"
        assert self.client.model == "gpt-4"
        assert self.client.timeout == 30

    @patch('llm_client.requests.post')
    def test_chat_without_context(self, mock_post):
        """Test chat method without context."""
        mock_response = Mock()
        mock_response.json.return_value = {
            "choices": [
                {"message": {"content": "Hello from GPT-4!"}}
            ]
        }
        mock_post.return_value = mock_response
        messages = [{"role": "user", "content": "Hi"}]
        response = self.client.chat(messages)
        assert response == "Hello from GPT-4!"

    @patch('llm_client.requests.post')
    def test_chat_with_context(self, mock_post):
        """Test chat method with context."""
        mock_response = Mock()
        mock_response.json.return_value = {
            "choices": [
                {"message": {"content": "Based on the context..."}}
            ]
        }
        mock_post.return_value = mock_response
        messages = [{"role": "user", "content": "Summarize"}]
        context = "Test document content"
        response = self.client.chat(messages, context=context)
        assert response == "Based on the context..."

    @patch('llm_client.requests.get')
    def test_is_available_success(self, mock_get):
        """Test is_available returns True with valid API key."""
        mock_response = Mock()
        mock_response.status_code = 200
        mock_get.return_value = mock_response
        assert self.client.is_available() is True

    @patch('llm_client.requests.get')
    def test_is_available_failure(self, mock_get):
        """Test is_available returns False with invalid API key."""
        mock_get.side_effect = Exception("Unauthorized")
        assert self.client.is_available() is False


class TestAnthropicClient:
    """Test AnthropicClient class."""

    def setup_method(self):
        """Set up test fixtures."""
        self.client = AnthropicClient(
            api_key="sk-ant-test123",
            model="claude-3-5-sonnet-20241022",
            timeout=30
        )

    def test_initialization(self):
        """Test client initialization."""
        assert self.client.api_key == "sk-ant-test123"
        assert self.client.model == "claude-3-5-sonnet-20241022"
        assert self.client.timeout == 30

    @patch('llm_client.requests.post')
    def test_chat_without_context(self, mock_post):
        """Test chat method without context."""
        mock_response = Mock()
        mock_response.json.return_value = {
            "content": [{"text": "Hello from Claude!"}]
        }
        mock_post.return_value = mock_response
        messages = [{"role": "user", "content": "Hi"}]
        response = self.client.chat(messages)
        assert response == "Hello from Claude!"

    @patch('llm_client.requests.post')
    def test_chat_with_context(self, mock_post):
        """Test chat method with context."""
        mock_response = Mock()
        mock_response.json.return_value = {
            "content": [{"text": "Based on the document..."}]
        }
        mock_post.return_value = mock_response
        messages = [{"role": "user", "content": "Summarize"}]
        context = "Test document"
        response = self.client.chat(messages, context=context)
        assert response == "Based on the document..."
        # Verify system prompt was added
        call_args = mock_post.call_args
        sent_payload = call_args[1]['json']
        assert 'system' in sent_payload
        assert 'Test document' in sent_payload['system']

    def test_is_available_with_api_key(self):
        """Test is_available returns True when API key exists."""
        assert self.client.is_available() is True

    def test_is_available_without_api_key(self):
        """Test is_available returns False without API key."""
        client = AnthropicClient(api_key="")
        assert client.is_available() is False


class TestCreateLLMClient:
    """Test create_llm_client factory function."""

    def test_create_ollama_client(self):
        """Test creating Ollama client."""
        client = create_llm_client("ollama", model="llama3.2")
        assert isinstance(client, OllamaClient)
        assert client.model == "llama3.2"

    def test_create_openai_client(self):
        """Test creating OpenAI client."""
        client = create_llm_client("openai", api_key="sk-test", model="gpt-4")
        assert isinstance(client, OpenAIClient)
        assert client.model == "gpt-4"

    def test_create_openai_client_without_key(self):
        """Test creating OpenAI client without API key raises error."""
        with pytest.raises(ValueError, match="OpenAI API key required"):
            create_llm_client("openai")

    def test_create_anthropic_client(self):
        """Test creating Anthropic client."""
        client = create_llm_client("anthropic", api_key="sk-ant-test")
        assert isinstance(client, AnthropicClient)

    def test_create_anthropic_client_without_key(self):
        """Test creating Anthropic client without API key raises error."""
        with pytest.raises(ValueError, match="Anthropic API key required"):
            create_llm_client("anthropic")

    def test_create_unknown_provider(self):
        """Test creating client with unknown provider raises error."""
        with pytest.raises(ValueError, match="Unknown LLM provider"):
            create_llm_client("unknown_provider")

    @patch.dict('os.environ', {'OLLAMA_URL': 'http://custom:11434', 'OLLAMA_MODEL': 'mistral'})
    def test_create_ollama_with_env_vars(self):
        """Test creating Ollama client with environment variables."""
        client = create_llm_client("ollama")
        assert isinstance(client, OllamaClient)
        assert client.base_url == "http://custom:11434"
        assert client.model == "mistral"
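Note on the context tests: the Ollama test above expects the document context to arrive as a leading system message. llm_client.py is not shown in this part of the diff, so the following is only a sketch of one way to build such a message list. The helper name `with_context` and the prompt wording are assumptions, not the commit's implementation.

```python
from typing import Dict, List, Optional


def with_context(messages: List[Dict[str, str]], context: Optional[str] = None) -> List[Dict[str, str]]:
    """Return a new message list with the context prepended as a system message."""
    if not context:
        # No context: return a copy so the caller's list is never mutated
        return list(messages)
    system = {
        "role": "system",
        "content": f"Answer using the following document:\n\n{context}",
    }
    return [system, *messages]
```

The Anthropic test checks a top-level `system` field in the payload instead, so a client for that API would pass the same text there rather than as a message.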

71
toothfairy_client.py Normal file

@@ -0,0 +1,71 @@
import os
import requests


class ToothFairyClient:
    def __init__(self):
        self.api_key = os.getenv('TOOTHFAIRYAI_API_KEY')
        self.workspace_id = os.getenv('TOOTHFAIRYAI_WORKSPACE_ID')
        self.base_url = os.getenv('TOOTHFAIRYAI_API_URL', 'https://api.toothfairyai.com')

    def _get_headers(self):
        return {
            "x-api-key": self.api_key,
            "Content-Type": "application/json"
        }

    def create_search_tool(self, ngrok_url):
        """Create the API tool that tells the agent how to search your DB."""
        # Endpoint based on the documentation
        endpoint = f"{self.base_url}/function/create"
        payload = {
            "workspaceid": self.workspace_id,
            "name": "Search_SharePoint_Database",
            "url": f"{ngrok_url.rstrip('/')}/api/search/chunks",
            "requestType": "POST",
            "authorisationType": "none",
            "description": "Searches the company's internal SharePoint database for document excerpts. Use this whenever the user asks about internal files, reports, or policies.",
            "parameters": [
                {
                    "name": "query",
                    "type": "string",
                    "required": True,
                    "description": "The exact search query to find relevant information."
                }
            ]
        }
        response = requests.post(endpoint, json=payload, headers=self._get_headers(), timeout=30)
        response.raise_for_status()
        # The ID may be nested under "data" or returned at the top level
        response_json = response.json()
        return (response_json.get("data") or {}).get("id") or response_json.get("id")

    def create_agent(self, tool_id):
        """Create the agent (mode: retriever) and attach the new search tool."""
        endpoint = f"{self.base_url}/agent/create"
        payload = {
            "workspaceid": self.workspace_id,
            "label": "SharePoint Assistant",
            "mode": "retriever",
            "interpolationString": "You are a helpful corporate assistant. Use your Search_SharePoint_Database tool to find answers to user questions. Always cite the 'source_file' in your responses.",
            "goals": "Answer questions accurately using internal documentation.",
            "temperature": 0.3,
            "maxTokens": 2000,
            "maxHistory": 10,  # required field
            "topK": 10,  # required field
            "docTopK": 5,  # required field
            "hasFunctions": True,
            "agentFunctions": [tool_id],
            "agenticRAG": True
        }
        response = requests.post(endpoint, json=payload, headers=self._get_headers(), timeout=30)
        if not response.ok:
            print(f"❌ Server Error Output: {response.text}")
        response.raise_for_status()
        response_json = response.json()
        return (response_json.get("data") or {}).get("id") or response_json.get("id")
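Note on response parsing: both methods look for the new ID either under a `data` wrapper or at the top level of the response. The exact response shape of the ToothFairyAI API is not confirmed by this commit, so the sketch below only illustrates a defensive version of that lookup. The helper name `extract_id` is hypothetical.

```python
from typing import Any, Dict, Optional


def extract_id(response_json: Dict[str, Any]) -> Optional[str]:
    """Return the ID from {"data": {"id": ...}} or {"id": ...}, else None."""
    data = response_json.get("data")
    # Guard against "data" being null or a non-object value
    if isinstance(data, dict) and data.get("id"):
        return data["id"]
    return response_json.get("id")
```

Returning None rather than raising leaves it to the caller to decide whether a missing ID is an error; setup_agent.py, for example, would otherwise pass None on to `create_agent`.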

477
vector_store.py Normal file

@@ -0,0 +1,477 @@
"""
Vector Store for Multi-Document Chat with RAG (Retrieval-Augmented Generation)
Supports document embeddings, tagging, and semantic search for chatting with
multiple documents at once.
"""
import os
import json
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, asdict
from datetime import datetime
import hashlib


@dataclass
class DocumentChunk:
    """A chunk of a document with metadata."""
    chunk_id: str
    document_id: str
    content: str
    chunk_index: int
    embedding: Optional[List[float]] = None
    metadata: Optional[Dict[str, Any]] = None


@dataclass
class Document:
    """Document metadata with tags."""
    document_id: str
    user_id: str
    site_id: str
    file_path: str
    filename: str
    tags: List[str]
    content_hash: str
    created_at: str
    updated_at: str
    chunk_count: int
    metadata: Optional[Dict[str, Any]] = None


class EmbeddingProvider:
    """Base class for embedding providers."""

    def embed_text(self, text: str) -> List[float]:
        """Generate embedding for text."""
        raise NotImplementedError

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple texts."""
        return [self.embed_text(text) for text in texts]


class OllamaEmbeddings(EmbeddingProvider):
    """Ollama embeddings using local models."""

    def __init__(self, base_url: str = "http://localhost:11434", model: str = "nomic-embed-text"):
        """
        Initialize Ollama embeddings.

        Args:
            base_url: Ollama server URL
            model: Embedding model (e.g., nomic-embed-text, mxbai-embed-large)
        """
        import requests
        self.base_url = base_url.rstrip('/')
        self.model = model
        self.requests = requests

    def embed_text(self, text: str) -> List[float]:
        """Generate embedding using Ollama."""
        response = self.requests.post(
            f"{self.base_url}/api/embeddings",
            json={
                "model": self.model,
                "prompt": text
            }
        )
        response.raise_for_status()
        return response.json()["embedding"]


class OpenAIEmbeddings(EmbeddingProvider):
    """OpenAI embeddings."""

    def __init__(self, api_key: str, model: str = "text-embedding-3-small"):
        """
        Initialize OpenAI embeddings.

        Args:
            api_key: OpenAI API key
            model: Embedding model (text-embedding-3-small or text-embedding-3-large)
        """
        import requests
        self.api_key = api_key
        self.model = model
        self.requests = requests

    def embed_text(self, text: str) -> List[float]:
        """Generate embedding using OpenAI."""
        response = self.requests.post(
            "https://api.openai.com/v1/embeddings",
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            },
            json={
                "input": text,
                "model": self.model
            }
        )
        response.raise_for_status()
        return response.json()["data"][0]["embedding"]
class InMemoryVectorStore:
"""
Simple in-memory vector store for development.
For production, use a proper vector database like:
- Pinecone
- Weaviate
- Qdrant
- ChromaDB
- pgvector (PostgreSQL extension)
"""
def __init__(self, embedding_provider: EmbeddingProvider):
"""Initialize vector store."""
self.embedding_provider = embedding_provider
self.documents: Dict[str, Document] = {}
self.chunks: Dict[str, DocumentChunk] = {}
self.user_documents: Dict[str, List[str]] = {} # user_id -> [document_ids]
self.tag_index: Dict[str, List[str]] = {} # tag -> [document_ids]
def add_document(
self,
user_id: str,
site_id: str,
file_path: str,
filename: str,
content: str,
tags: List[str],
chunk_size: int = 1000,
chunk_overlap: int = 200,
metadata: Optional[Dict[str, Any]] = None
) -> str:
"""
Add a document to the vector store with chunking and embeddings.
Args:
user_id: User ID
site_id: SharePoint site ID
file_path: File path
filename: Filename
content: Document content
tags: List of tags (e.g., ["HR", "SALES", "Q4-2024"])
chunk_size: Size of each chunk in characters
chunk_overlap: Overlap between chunks
metadata: Additional metadata
Returns:
document_id
"""
# Generate document ID
document_id = self._generate_document_id(user_id, site_id, file_path)
# Calculate content hash
content_hash = hashlib.sha256(content.encode()).hexdigest()
# Check if document already exists and hasn't changed
if document_id in self.documents:
existing_doc = self.documents[document_id]
if existing_doc.content_hash == content_hash:
# Update tags if different
if set(tags) != set(existing_doc.tags):
self._update_tags(document_id, existing_doc.tags, tags)
existing_doc.tags = tags
existing_doc.updated_at = datetime.utcnow().isoformat()
return document_id
else:
# Content changed: remove old chunks and stale tag-index entries
self._remove_document_chunks(document_id)
self._update_tags(document_id, existing_doc.tags, [])
# Chunk the document
chunks = self._chunk_text(content, chunk_size, chunk_overlap)
# Generate embeddings for all chunks
chunk_texts = list(chunks)
embeddings = self.embedding_provider.embed_batch(chunk_texts)
# Create document chunks
document_chunks = []
for idx, (chunk_text, embedding) in enumerate(zip(chunk_texts, embeddings)):
chunk_id = f"{document_id}_chunk_{idx}"
chunk = DocumentChunk(
chunk_id=chunk_id,
document_id=document_id,
content=chunk_text,
chunk_index=idx,
embedding=embedding,
metadata=metadata
)
self.chunks[chunk_id] = chunk
document_chunks.append(chunk)
# Create document metadata
now = datetime.utcnow().isoformat()
document = Document(
document_id=document_id,
user_id=user_id,
site_id=site_id,
file_path=file_path,
filename=filename,
tags=tags,
content_hash=content_hash,
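# Keep the original creation time when an existing document is re-indexed
created_at=self.documents[document_id].created_at if document_id in self.documents else now,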
updated_at=now,
chunk_count=len(document_chunks),
metadata=metadata
)
# Store document
self.documents[document_id] = document
# Update user index
if user_id not in self.user_documents:
self.user_documents[user_id] = []
if document_id not in self.user_documents[user_id]:
self.user_documents[user_id].append(document_id)
# Update tag index
for tag in tags:
tag_lower = tag.lower()
if tag_lower not in self.tag_index:
self.tag_index[tag_lower] = []
if document_id not in self.tag_index[tag_lower]:
self.tag_index[tag_lower].append(document_id)
return document_id
def search(
self,
user_id: str,
query: str,
tags: Optional[List[str]] = None,
top_k: int = 5
) -> List[Dict[str, Any]]:
"""
Search for relevant document chunks.
Args:
user_id: User ID (for access control)
query: Search query
tags: Optional list of tags to filter by
top_k: Number of results to return
Returns:
List of relevant chunks with similarity scores
"""
# Get query embedding
query_embedding = self.embedding_provider.embed_text(query)
# Get candidate document IDs
if tags:
# Filter by tags
candidate_doc_ids = set()
for tag in tags:
tag_lower = tag.lower()
if tag_lower in self.tag_index:
candidate_doc_ids.update(self.tag_index[tag_lower])
# Intersect with user's documents
user_doc_ids = set(self.user_documents.get(user_id, []))
candidate_doc_ids = candidate_doc_ids.intersection(user_doc_ids)
else:
# All user's documents
candidate_doc_ids = set(self.user_documents.get(user_id, []))
# Get chunks from candidate documents
candidate_chunks = [
chunk for chunk in self.chunks.values()
if chunk.document_id in candidate_doc_ids
]
# Calculate cosine similarity for each chunk
results = []
for chunk in candidate_chunks:
if chunk.embedding:
similarity = self._cosine_similarity(query_embedding, chunk.embedding)
results.append({
"chunk": chunk,
"document": self.documents[chunk.document_id],
"similarity": similarity
})
# Sort by similarity and return top_k
results.sort(key=lambda x: x["similarity"], reverse=True)
return results[:top_k]
def get_documents_by_tags(self, user_id: str, tags: List[str]) -> List[Document]:
"""Get all documents with specific tags."""
doc_ids = set()
for tag in tags:
tag_lower = tag.lower()
if tag_lower in self.tag_index:
doc_ids.update(self.tag_index[tag_lower])
# Filter by user
user_doc_ids = set(self.user_documents.get(user_id, []))
doc_ids = doc_ids.intersection(user_doc_ids)
return [self.documents[doc_id] for doc_id in doc_ids if doc_id in self.documents]
def list_tags(self, user_id: str) -> Dict[str, int]:
"""List all tags for user with document counts."""
user_doc_ids = set(self.user_documents.get(user_id, []))
tag_counts = {}
for tag, doc_ids in self.tag_index.items():
user_tagged_docs = set(doc_ids).intersection(user_doc_ids)
if user_tagged_docs:
tag_counts[tag] = len(user_tagged_docs)
return tag_counts
def update_document_tags(self, document_id: str, user_id: str, tags: List[str]):
"""Update tags for a document."""
if document_id not in self.documents:
raise ValueError("Document not found")
doc = self.documents[document_id]
if doc.user_id != user_id:
raise ValueError("Access denied")
old_tags = doc.tags
self._update_tags(document_id, old_tags, tags)
doc.tags = tags
doc.updated_at = datetime.utcnow().isoformat()
def remove_document(self, document_id: str, user_id: str):
"""Remove a document and its chunks."""
if document_id not in self.documents:
return
doc = self.documents[document_id]
if doc.user_id != user_id:
raise ValueError("Access denied")
# Remove from tag index
for tag in doc.tags:
tag_lower = tag.lower()
if tag_lower in self.tag_index:
self.tag_index[tag_lower] = [
d for d in self.tag_index[tag_lower] if d != document_id
]
# Remove chunks
self._remove_document_chunks(document_id)
# Remove from user index
if user_id in self.user_documents:
self.user_documents[user_id] = [
d for d in self.user_documents[user_id] if d != document_id
]
# Remove document
del self.documents[document_id]
def _chunk_text(self, text: str, chunk_size: int, overlap: int) -> List[str]:
"""Split text into overlapping chunks."""
if not text:
return []
chunks = []
start = 0
text_length = len(text)
while start < text_length:
end = start + chunk_size
# Try to break at sentence boundary
if end < text_length:
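# Look for sentence endings, but only past the overlap window, so that
# the next start (end - overlap) always moves forward
for punct in ['. ', '! ', '? ', '\n\n']:
last_punct = text.rfind(punct, start + overlap, end)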
if last_punct != -1:
end = last_punct + len(punct)
break
chunk = text[start:end].strip()
if chunk:
chunks.append(chunk)
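# Always advance by at least one character; otherwise overlap >= chunk_size
# (or an early sentence boundary) would loop forever
start = max(end - overlap, start + 1) if end < text_length else text_length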
return chunks
def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
"""Calculate cosine similarity between two vectors."""
import math
dot_product = sum(a * b for a, b in zip(vec1, vec2))
magnitude1 = math.sqrt(sum(a * a for a in vec1))
magnitude2 = math.sqrt(sum(b * b for b in vec2))
if magnitude1 == 0 or magnitude2 == 0:
return 0.0
return dot_product / (magnitude1 * magnitude2)
def _generate_document_id(self, user_id: str, site_id: str, file_path: str) -> str:
"""Generate unique document ID."""
combined = f"{user_id}:{site_id}:{file_path}"
return hashlib.sha256(combined.encode()).hexdigest()
def _remove_document_chunks(self, document_id: str):
"""Remove all chunks for a document."""
chunk_ids_to_remove = [
chunk_id for chunk_id, chunk in self.chunks.items()
if chunk.document_id == document_id
]
for chunk_id in chunk_ids_to_remove:
del self.chunks[chunk_id]
def _update_tags(self, document_id: str, old_tags: List[str], new_tags: List[str]):
"""Update tag index when tags change."""
# Remove from old tags
for tag in old_tags:
tag_lower = tag.lower()
if tag_lower in self.tag_index:
self.tag_index[tag_lower] = [
d for d in self.tag_index[tag_lower] if d != document_id
]
# Add to new tags
for tag in new_tags:
tag_lower = tag.lower()
if tag_lower not in self.tag_index:
self.tag_index[tag_lower] = []
if document_id not in self.tag_index[tag_lower]:
self.tag_index[tag_lower].append(document_id)
def create_embedding_provider(provider: str = "ollama", **kwargs) -> EmbeddingProvider:
"""
Factory function to create embedding provider.
Args:
provider: Provider name ("ollama" or "openai")
**kwargs: Provider-specific configuration
Returns:
EmbeddingProvider instance
Example:
# Ollama (default)
embeddings = create_embedding_provider("ollama", model="nomic-embed-text")
# OpenAI
embeddings = create_embedding_provider("openai", api_key="sk-...")
"""
if provider.lower() == "ollama":
return OllamaEmbeddings(
base_url=kwargs.get("base_url", os.getenv("OLLAMA_URL", "http://localhost:11434")),
model=kwargs.get("model", os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text"))
)
elif provider.lower() == "openai":
api_key = kwargs.get("api_key", os.getenv("OPENAI_API_KEY"))
if not api_key:
raise ValueError("OpenAI API key required")
return OpenAIEmbeddings(
api_key=api_key,
model=kwargs.get("model", "text-embedding-3-small")
)
else:
raise ValueError(f"Unknown embedding provider: {provider}")
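
The chunking step used by `add_document` can be sketched on its own. This is a minimal illustration, not the module's implementation: the function name `chunk_text`, the `SENTENCE_ENDINGS` constant, and the argument validation are assumptions made for the example. Unlike `_chunk_text`, it takes the latest boundary of any type rather than the first punctuation type that matches, and it only accepts boundaries past the overlap window so that each iteration moves forward.

```python
from typing import List

# Same boundary markers that _chunk_text looks for.
SENTENCE_ENDINGS = (". ", "! ", "? ", "\n\n")


def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 200) -> List[str]:
    """Split text into overlapping chunks, preferring sentence boundaries."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    chunks: List[str] = []
    start = 0
    n = len(text)
    while start < n:
        end = min(start + chunk_size, n)
        if end < n:
            # Only accept a boundary past the overlap window, so the next
            # start (end - overlap) is strictly greater than this one.
            cut = -1
            for punct in SENTENCE_ENDINGS:
                pos = text.rfind(punct, start + overlap, end)
                if pos != -1:
                    cut = max(cut, pos + len(punct))
            if cut != -1:
                end = cut
        piece = text[start:end].strip()
        if piece:
            chunks.append(piece)
        if end >= n:
            break
        start = end - overlap
    return chunks
```

With `chunk_size=100` and `overlap=20`, every chunk of a sentence-heavy text ends at a sentence boundary and consecutive chunks share up to 20 characters.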

652
vector_store_postgres.py Normal file

@@ -0,0 +1,652 @@
"""
PostgreSQL Vector Store with pgvector for Multi-Document Chat with RAG
Persistent vector storage using PostgreSQL with pgvector extension.
"""
import os
import hashlib
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from datetime import datetime
import psycopg2
from psycopg2.extras import RealDictCursor, execute_values
from pgvector.psycopg2 import register_vector
@dataclass
class DocumentChunk:
"""A chunk of a document with metadata."""
chunk_id: str
document_id: str
content: str
chunk_index: int
embedding: Optional[List[float]] = None
metadata: Optional[Dict[str, Any]] = None
@dataclass
class Document:
"""Document metadata with tags."""
document_id: str
user_id: str
site_id: str
file_path: str
filename: str
tags: List[str]
content_hash: str
created_at: str
updated_at: str
chunk_count: int
metadata: Optional[Dict[str, Any]] = None
class EmbeddingProvider:
"""Base class for embedding providers."""
def embed_text(self, text: str) -> List[float]:
"""Generate embedding for text."""
raise NotImplementedError
def embed_batch(self, texts: List[str]) -> List[List[float]]:
"""Generate embeddings for multiple texts."""
return [self.embed_text(text) for text in texts]
class OllamaEmbeddings(EmbeddingProvider):
"""Ollama embeddings using local models."""
def __init__(self, base_url: str = "http://localhost:11434", model: str = "nomic-embed-text"):
"""
Initialize Ollama embeddings.
Args:
base_url: Ollama server URL
model: Embedding model (e.g., nomic-embed-text, mxbai-embed-large)
"""
import requests
self.base_url = base_url.rstrip('/')
self.model = model
self.requests = requests
def embed_text(self, text: str) -> List[float]:
"""Generate embedding using Ollama."""
response = self.requests.post(
f"{self.base_url}/api/embeddings",
json={
"model": self.model,
"prompt": text
}
)
response.raise_for_status()
return response.json()["embedding"]
class OpenAIEmbeddings(EmbeddingProvider):
"""OpenAI embeddings."""
def __init__(self, api_key: str, model: str = "text-embedding-3-small"):
"""
Initialize OpenAI embeddings.
Args:
api_key: OpenAI API key
model: Embedding model (text-embedding-3-small or text-embedding-3-large)
"""
import requests
self.api_key = api_key
self.model = model
self.requests = requests
def embed_text(self, text: str) -> List[float]:
"""Generate embedding using OpenAI."""
response = self.requests.post(
"https://api.openai.com/v1/embeddings",
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
},
json={
"input": text,
"model": self.model
}
)
response.raise_for_status()
return response.json()["data"][0]["embedding"]
class PostgreSQLVectorStore:
"""
PostgreSQL vector store with pgvector extension.
Provides persistent storage for document embeddings with tag-based organization.
"""
def __init__(
self,
embedding_provider: EmbeddingProvider,
connection_string: Optional[str] = None,
table_prefix: str = ""
):
"""
Initialize PostgreSQL vector store.
Args:
embedding_provider: Provider for generating embeddings
connection_string: PostgreSQL connection string (uses env vars if not provided)
table_prefix: Prefix for table names (e.g., "dev_" or "prod_")
"""
self.embedding_provider = embedding_provider
self.table_prefix = table_prefix
# Use connection string or build from env vars
if connection_string:
self.connection_string = connection_string
else:
self.connection_string = (
f"host={os.getenv('POSTGRES_HOST', 'localhost')} "
f"port={os.getenv('POSTGRES_PORT', '5432')} "
f"dbname={os.getenv('POSTGRES_DB', 'sharepoint_vectors')} "
f"user={os.getenv('POSTGRES_USER', 'postgres')} "
f"password={os.getenv('POSTGRES_PASSWORD', 'postgres')}"
)
# Initialize database connection
self.conn = psycopg2.connect(self.connection_string)
self.conn.autocommit = False
# Create tables and enable extension if they don't exist
self._create_tables()
# Register pgvector type (must be after extension is enabled)
register_vector(self.conn)
def _create_tables(self):
"""Create tables for documents, chunks, and tags."""
with self.conn.cursor() as cur:
# Enable pgvector extension
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
# Documents table
cur.execute(f"""
CREATE TABLE IF NOT EXISTS {self.table_prefix}documents (
document_id VARCHAR(64) PRIMARY KEY,
user_id VARCHAR(255) NOT NULL,
site_id VARCHAR(255) NOT NULL,
file_path TEXT NOT NULL,
filename VARCHAR(255) NOT NULL,
content_hash VARCHAR(64) NOT NULL,
created_at TIMESTAMP NOT NULL,
updated_at TIMESTAMP NOT NULL,
chunk_count INTEGER NOT NULL,
metadata JSONB,
UNIQUE(user_id, site_id, file_path)
)
""")
# Document chunks table with vector embeddings
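# The column is fixed at 768 dimensions, which matches nomic-embed-text.
# Other models produce different sizes (e.g., text-embedding-3-small returns 1536),
# so change vector(768) to match the embedding model before creating the table.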
cur.execute(f"""
CREATE TABLE IF NOT EXISTS {self.table_prefix}document_chunks (
chunk_id VARCHAR(128) PRIMARY KEY,
document_id VARCHAR(64) NOT NULL REFERENCES {self.table_prefix}documents(document_id) ON DELETE CASCADE,
content TEXT NOT NULL,
chunk_index INTEGER NOT NULL,
embedding vector(768),
metadata JSONB,
UNIQUE(document_id, chunk_index)
)
""")
# Tags table (many-to-many with documents)
cur.execute(f"""
CREATE TABLE IF NOT EXISTS {self.table_prefix}document_tags (
document_id VARCHAR(64) NOT NULL REFERENCES {self.table_prefix}documents(document_id) ON DELETE CASCADE,
tag VARCHAR(100) NOT NULL,
PRIMARY KEY(document_id, tag)
)
""")
# Create indexes for performance
cur.execute(f"""
CREATE INDEX IF NOT EXISTS idx_documents_user_id
ON {self.table_prefix}documents(user_id)
""")
cur.execute(f"""
CREATE INDEX IF NOT EXISTS idx_document_tags_tag
ON {self.table_prefix}document_tags(tag)
""")
# An IVFFlat index for vector similarity search is intentionally not created here,
# because it should be built after some data exists. Create it manually after inserting
# ~1000+ vectors (prepend the table prefix to the table name if one is configured):
# CREATE INDEX ON document_chunks USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);
self.conn.commit()
def add_document(
self,
user_id: str,
site_id: str,
file_path: str,
filename: str,
content: str,
tags: List[str],
chunk_size: int = 1000,
chunk_overlap: int = 200,
metadata: Optional[Dict[str, Any]] = None
) -> str:
"""
Add a document to the vector store with chunking and embeddings.
Args:
user_id: User ID
site_id: SharePoint site ID
file_path: File path
filename: Filename
content: Document content
tags: List of tags (e.g., ["HR", "SALES", "Q4-2024"])
chunk_size: Size of each chunk in characters
chunk_overlap: Overlap between chunks
metadata: Additional metadata
Returns:
document_id
"""
# Generate document ID
document_id = self._generate_document_id(user_id, site_id, file_path)
# Calculate content hash
content_hash = hashlib.sha256(content.encode()).hexdigest()
with self.conn.cursor() as cur:
# Check if document already exists
cur.execute(
f"SELECT content_hash FROM {self.table_prefix}documents WHERE document_id = %s",
(document_id,)
)
existing = cur.fetchone()
if existing and existing[0] == content_hash:
# Content hasn't changed, just update tags if different
cur.execute(
f"SELECT tag FROM {self.table_prefix}document_tags WHERE document_id = %s",
(document_id,)
)
existing_tags = [row[0] for row in cur.fetchall()]
# Stored tags are lowercased, so compare case-insensitively
if {tag.lower() for tag in tags} != set(existing_tags):
self._update_tags(cur, document_id, tags)
cur.execute(
f"UPDATE {self.table_prefix}documents SET updated_at = %s WHERE document_id = %s",
(datetime.utcnow(), document_id)
)
self.conn.commit()
return document_id
# Delete old chunks if content changed
if existing:
cur.execute(
f"DELETE FROM {self.table_prefix}document_chunks WHERE document_id = %s",
(document_id,)
)
# Chunk the document
chunks = self._chunk_text(content, chunk_size, chunk_overlap)
# Generate embeddings for all chunks
embeddings = self.embedding_provider.embed_batch(chunks)
# Insert or update document
now = datetime.utcnow()
if existing:
cur.execute(f"""
UPDATE {self.table_prefix}documents
SET content_hash = %s, updated_at = %s, chunk_count = %s, metadata = %s
WHERE document_id = %s
""", (content_hash, now, len(chunks), psycopg2.extras.Json(metadata), document_id))
else:
cur.execute(f"""
INSERT INTO {self.table_prefix}documents
(document_id, user_id, site_id, file_path, filename, content_hash, created_at, updated_at, chunk_count, metadata)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", (document_id, user_id, site_id, file_path, filename, content_hash, now, now, len(chunks), psycopg2.extras.Json(metadata)))
# Insert chunks with embeddings
chunk_data = [
(
f"{document_id}_chunk_{idx}",
document_id,
chunk_text,
idx,
embedding,
psycopg2.extras.Json(metadata)
)
for idx, (chunk_text, embedding) in enumerate(zip(chunks, embeddings))
]
execute_values(
cur,
f"""
INSERT INTO {self.table_prefix}document_chunks
(chunk_id, document_id, content, chunk_index, embedding, metadata)
VALUES %s
""",
chunk_data
)
# Update tags
self._update_tags(cur, document_id, tags)
self.conn.commit()
return document_id
def search(
self,
user_id: str,
query: str,
tags: Optional[List[str]] = None,
top_k: int = 5
) -> List[Dict[str, Any]]:
"""
Search for relevant document chunks using vector similarity.
Args:
user_id: User ID (for access control)
query: Search query
tags: Optional list of tags to filter by
top_k: Number of results to return
Returns:
List of relevant chunks with similarity scores
"""
# Generate query embedding
query_embedding = self.embedding_provider.embed_text(query)
with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
# Build query based on whether tags are specified
if tags:
# Filter by tags using JOIN
tag_placeholders = ','.join(['%s'] * len(tags))
cur.execute(f"""
SELECT
c.chunk_id, c.document_id, c.content, c.chunk_index, c.metadata as chunk_metadata,
d.user_id, d.site_id, d.file_path, d.filename, d.created_at, d.updated_at,
d.chunk_count, d.content_hash, d.metadata as doc_metadata,
ARRAY_AGG(dt.tag) as tags,
1 - (c.embedding <=> %s::vector) as similarity
FROM {self.table_prefix}document_chunks c
JOIN {self.table_prefix}documents d ON c.document_id = d.document_id
JOIN {self.table_prefix}document_tags dt ON d.document_id = dt.document_id
WHERE d.user_id = %s AND LOWER(dt.tag) IN ({tag_placeholders})
GROUP BY c.chunk_id, c.document_id, c.content, c.chunk_index, c.metadata,
d.user_id, d.site_id, d.file_path, d.filename, d.created_at,
d.updated_at, d.chunk_count, d.content_hash, d.metadata, c.embedding
ORDER BY similarity DESC
LIMIT %s
""", [query_embedding, user_id] + [tag.lower() for tag in tags] + [top_k])
else:
# All user's documents
cur.execute(f"""
SELECT
c.chunk_id, c.document_id, c.content, c.chunk_index, c.metadata as chunk_metadata,
d.user_id, d.site_id, d.file_path, d.filename, d.created_at, d.updated_at,
d.chunk_count, d.content_hash, d.metadata as doc_metadata,
ARRAY_AGG(dt.tag) as tags,
1 - (c.embedding <=> %s::vector) as similarity
FROM {self.table_prefix}document_chunks c
JOIN {self.table_prefix}documents d ON c.document_id = d.document_id
LEFT JOIN {self.table_prefix}document_tags dt ON d.document_id = dt.document_id
WHERE d.user_id = %s
GROUP BY c.chunk_id, c.document_id, c.content, c.chunk_index, c.metadata,
d.user_id, d.site_id, d.file_path, d.filename, d.created_at,
d.updated_at, d.chunk_count, d.content_hash, d.metadata, c.embedding
ORDER BY similarity DESC
LIMIT %s
""", [query_embedding, user_id, top_k])
rows = cur.fetchall()
# Convert to result format
results = []
for row in rows:
chunk = DocumentChunk(
chunk_id=row['chunk_id'],
document_id=row['document_id'],
content=row['content'],
chunk_index=row['chunk_index'],
metadata=row['chunk_metadata']
)
document = Document(
document_id=row['document_id'],
user_id=row['user_id'],
site_id=row['site_id'],
file_path=row['file_path'],
filename=row['filename'],
tags=row['tags'] or [],
content_hash=row['content_hash'],
created_at=row['created_at'].isoformat(),
updated_at=row['updated_at'].isoformat(),
chunk_count=row['chunk_count'],
metadata=row['doc_metadata']
)
results.append({
"chunk": chunk,
"document": document,
"similarity": float(row['similarity'])
})
return results
def get_documents_by_tags(self, user_id: str, tags: List[str]) -> List[Document]:
"""Get all documents with specific tags."""
# An empty tag list would produce an invalid "IN ()" clause
if not tags:
return []
with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
tag_placeholders = ','.join(['%s'] * len(tags))
cur.execute(f"""
SELECT DISTINCT
d.document_id, d.user_id, d.site_id, d.file_path, d.filename,
d.content_hash, d.created_at, d.updated_at, d.chunk_count, d.metadata,
ARRAY_AGG(dt.tag) as tags
FROM {self.table_prefix}documents d
JOIN {self.table_prefix}document_tags dt ON d.document_id = dt.document_id
WHERE d.user_id = %s AND LOWER(dt.tag) IN ({tag_placeholders})
GROUP BY d.document_id, d.user_id, d.site_id, d.file_path, d.filename,
d.content_hash, d.created_at, d.updated_at, d.chunk_count, d.metadata
""", [user_id] + [tag.lower() for tag in tags])
rows = cur.fetchall()
return [
Document(
document_id=row['document_id'],
user_id=row['user_id'],
site_id=row['site_id'],
file_path=row['file_path'],
filename=row['filename'],
tags=row['tags'] or [],
content_hash=row['content_hash'],
created_at=row['created_at'].isoformat(),
updated_at=row['updated_at'].isoformat(),
chunk_count=row['chunk_count'],
metadata=row['metadata']
)
for row in rows
]
def list_tags(self, user_id: str) -> Dict[str, int]:
"""List all tags for user with document counts."""
with self.conn.cursor() as cur:
cur.execute(f"""
SELECT dt.tag, COUNT(DISTINCT dt.document_id) as count
FROM {self.table_prefix}document_tags dt
JOIN {self.table_prefix}documents d ON dt.document_id = d.document_id
WHERE d.user_id = %s
GROUP BY dt.tag
ORDER BY count DESC, dt.tag
""", (user_id,))
return {row[0]: row[1] for row in cur.fetchall()}
def get_indexed_sites(self, user_id: str) -> List[str]:
"""
Get list of site IDs that have indexed documents for this user.
Args:
user_id: User ID
Returns:
List of site IDs with indexed documents
"""
with self.conn.cursor() as cur:
cur.execute(f"""
SELECT DISTINCT site_id
FROM {self.table_prefix}documents
WHERE user_id = %s
ORDER BY site_id
""", (user_id,))
return [row[0] for row in cur.fetchall()]
def update_document_tags(self, document_id: str, user_id: str, tags: List[str]):
"""Update tags for a document."""
with self.conn.cursor() as cur:
# Verify ownership
cur.execute(
f"SELECT user_id FROM {self.table_prefix}documents WHERE document_id = %s",
(document_id,)
)
row = cur.fetchone()
if not row:
raise ValueError("Document not found")
if row[0] != user_id:
raise ValueError("Access denied")
# Update tags
self._update_tags(cur, document_id, tags)
# Update timestamp
cur.execute(
f"UPDATE {self.table_prefix}documents SET updated_at = %s WHERE document_id = %s",
(datetime.utcnow(), document_id)
)
self.conn.commit()
def remove_document(self, document_id: str, user_id: str):
"""Remove a document and its chunks."""
with self.conn.cursor() as cur:
# Verify ownership
cur.execute(
f"SELECT user_id FROM {self.table_prefix}documents WHERE document_id = %s",
(document_id,)
)
row = cur.fetchone()
if not row:
return
if row[0] != user_id:
raise ValueError("Access denied")
# Delete document (cascades to chunks and tags)
cur.execute(
f"DELETE FROM {self.table_prefix}documents WHERE document_id = %s",
(document_id,)
)
self.conn.commit()
def _update_tags(self, cur, document_id: str, tags: List[str]):
"""Update tags for a document within a transaction."""
# Delete old tags
cur.execute(
f"DELETE FROM {self.table_prefix}document_tags WHERE document_id = %s",
(document_id,)
)
# Insert new tags
if tags:
# Deduplicate after lowercasing: "HR" and "hr" would violate the primary key
tag_data = [(document_id, tag) for tag in sorted({t.lower() for t in tags})]
execute_values(
cur,
f"INSERT INTO {self.table_prefix}document_tags (document_id, tag) VALUES %s",
tag_data
)
def _chunk_text(self, text: str, chunk_size: int, overlap: int) -> List[str]:
"""Split text into overlapping chunks."""
if not text:
return []
chunks = []
start = 0
text_length = len(text)
while start < text_length:
end = start + chunk_size
# Try to break at sentence boundary
if end < text_length:
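# Only accept a boundary past the overlap window, so that
# the next start (end - overlap) always moves forward
for punct in ['. ', '! ', '? ', '\n\n']:
last_punct = text.rfind(punct, start + overlap, end)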
if last_punct != -1:
end = last_punct + len(punct)
break
chunk = text[start:end].strip()
if chunk:
chunks.append(chunk)
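# Always advance by at least one character; otherwise overlap >= chunk_size
# (or an early sentence boundary) would loop forever
start = max(end - overlap, start + 1) if end < text_length else text_length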
return chunks
def _generate_document_id(self, user_id: str, site_id: str, file_path: str) -> str:
"""Generate unique document ID."""
combined = f"{user_id}:{site_id}:{file_path}"
return hashlib.sha256(combined.encode()).hexdigest()
def close(self):
"""Close database connection."""
if self.conn:
self.conn.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def create_embedding_provider(provider: str = "ollama", **kwargs) -> EmbeddingProvider:
"""
Factory function to create embedding provider.
Args:
provider: Provider name ("ollama" or "openai")
**kwargs: Provider-specific configuration
Returns:
EmbeddingProvider instance
"""
if provider.lower() == "ollama":
return OllamaEmbeddings(
base_url=kwargs.get("base_url", os.getenv("OLLAMA_URL", "http://localhost:11434")),
model=kwargs.get("model", os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text"))
)
elif provider.lower() == "openai":
api_key = kwargs.get("api_key", os.getenv("OPENAI_API_KEY"))
if not api_key:
raise ValueError("OpenAI API key required")
return OpenAIEmbeddings(
api_key=api_key,
model=kwargs.get("model", "text-embedding-3-small")
)
else:
raise ValueError(f"Unknown embedding provider: {provider}")
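
The ranking performed by `search` can be illustrated without a database. pgvector's `<=>` operator returns cosine distance, which is why the SQL above reports `1 - (c.embedding <=> query)` as the similarity. The sketch below computes the same quantity in plain Python and keeps the `top_k` best matches. The names `cosine_similarity` and `top_k_by_similarity` are hypothetical helpers for this example, and the explicit dimension check is an added assumption rather than something the stores above enforce.

```python
import math
from typing import Dict, List, Tuple


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero magnitude."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k_by_similarity(
    query: List[float],
    chunk_embeddings: Dict[str, List[float]],
    top_k: int = 5,
) -> List[Tuple[str, float]]:
    """Rank chunk IDs by similarity to the query, highest first."""
    scored = [
        (chunk_id, cosine_similarity(query, embedding))
        for chunk_id, embedding in chunk_embeddings.items()
    ]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:top_k]
```

For example, ranking the embeddings `{"a": [1, 0], "b": [0, 1], "c": [1, 1]}` against the query `[1, 0]` places `"a"` first and `"c"` second, because the orthogonal vector `"b"` has similarity 0.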