Initial commit: SharePoint connector and ToothFairyAI integration
Some checks failed
CI - SharePoint Plugin with SonarQube / Test and SonarQube Analysis (push) Has been cancelled
Some checks failed
CI - SharePoint Plugin with SonarQube / Test and SonarQube Analysis (push) Has been cancelled
This commit is contained in:
33
.env.example
Normal file
33
.env.example
Normal file
@@ -0,0 +1,33 @@
|
||||
# SharePoint Connector Configuration
|
||||
|
||||
# Azure App Registration (REQUIRED)
|
||||
SHAREPOINT_CLIENT_ID=your-client-id-here
|
||||
SHAREPOINT_CLIENT_SECRET=your-client-secret-here
|
||||
SHAREPOINT_TENANT_ID=common
|
||||
|
||||
# OAuth Callback (REQUIRED - must match Azure app registration)
|
||||
# Local development:
|
||||
REDIRECT_URI=http://localhost:5000/sharepoint/callback
|
||||
# Production:
|
||||
# REDIRECT_URI=https://yourdomain.com/sharepoint/callback
|
||||
|
||||
# Security (REQUIRED)
|
||||
# Generate with: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
|
||||
ENCRYPTION_KEY=your-generated-encryption-key-here
|
||||
|
||||
# Flask Security (OPTIONAL - auto-generated if not set)
|
||||
# Generate with: python -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
FLASK_SECRET_KEY=your-generated-flask-secret-here
|
||||
|
||||
# AWS Configuration (REQUIRED for DynamoDB)
|
||||
AWS_REGION=us-east-1
|
||||
# AWS_ACCESS_KEY_ID=your-key-id # Optional if using IAM role
|
||||
# AWS_SECRET_ACCESS_KEY=your-secret-key # Optional if using IAM role
|
||||
|
||||
# DynamoDB Settings (OPTIONAL)
|
||||
TABLE_PREFIX= # Optional: prefix table names (e.g., "dev_" or "prod_")
|
||||
# DYNAMODB_ENDPOINT=http://localhost:8000 # Uncomment for local DynamoDB
|
||||
|
||||
# Flask Settings (OPTIONAL)
|
||||
FLASK_DEBUG=false # Set to "true" for development
|
||||
PORT=5000
|
||||
96
.gitea/workflows/ci-sonarqube.yml
Normal file
96
.gitea/workflows/ci-sonarqube.yml
Normal file
@@ -0,0 +1,96 @@
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- develop
|
||||
- feature/*
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- develop
|
||||
- feature/*
|
||||
|
||||
name: CI - SharePoint Plugin with SonarQube
|
||||
|
||||
jobs:
|
||||
test-and-scan:
|
||||
name: Test and SonarQube Analysis
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: plugins/sharepoint
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Cache pip dependencies
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('plugins/sharepoint/requirements.txt') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-pip-
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install pytest pytest-cov pytest-mock
|
||||
|
||||
- name: Run tests with coverage
|
||||
run: |
|
||||
pytest \
|
||||
--cov=. \
|
||||
--cov-report=xml:coverage.xml \
|
||||
--cov-report=term \
|
||||
--cov-report=html \
|
||||
-v \
|
||||
--ignore=venv \
|
||||
--ignore=.venv
|
||||
continue-on-error: false
|
||||
|
||||
- name: Verify coverage file
|
||||
run: |
|
||||
pwd
|
||||
ls -la
|
||||
if [ -f coverage.xml ]; then
|
||||
echo "✅ Coverage file exists"
|
||||
echo "First 30 lines of coverage.xml:"
|
||||
head -30 coverage.xml
|
||||
else
|
||||
echo "❌ ERROR: Coverage file was not created!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: SonarQube Scan
|
||||
uses: SonarSource/sonarqube-scan-action@v6
|
||||
with:
|
||||
projectBaseDir: plugins/sharepoint
|
||||
args: >
|
||||
-Dsonar.projectKey=sharepoint-connector
|
||||
-Dsonar.projectName=SharePoint Connector Plugin
|
||||
-Dsonar.sources=.
|
||||
-Dsonar.python.coverage.reportPaths=coverage.xml
|
||||
-Dsonar.tests=.
|
||||
-Dsonar.test.inclusions=test_*.py
|
||||
-Dsonar.exclusions=venv/**,**/__pycache__/**,*.pyc,.venv/**,htmlcov/**,templates/**,static/**
|
||||
-Dsonar.python.version=3.11
|
||||
env:
|
||||
SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }}
|
||||
SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }}
|
||||
|
||||
- name: SonarQube Quality Gate Check
|
||||
uses: SonarSource/sonarqube-quality-gate-action@v1
|
||||
timeout-minutes: 5
|
||||
env:
|
||||
SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }}
|
||||
SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }}
|
||||
|
||||
44
.gitignore
vendored
Normal file
44
.gitignore
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
|
||||
# Virtual Environment
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
.venv
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Environment variables
|
||||
.env
|
||||
.env.local
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Application data
|
||||
/home/data/
|
||||
|
||||
# Pytest
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
coverage.xml
|
||||
htmlcov/
|
||||
|
||||
# Distribution
|
||||
dist/
|
||||
build/
|
||||
*.egg-info/
|
||||
23
Dockerfile
Normal file
23
Dockerfile
Normal file
@@ -0,0 +1,23 @@
|
||||
FROM python:3.11-slim

WORKDIR /app

# Install dependencies (copied first so this layer caches across code changes)
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application files
COPY app.py .
COPY saas_connector_dynamodb.py .
COPY templates/ templates/
COPY static/ static/

# Expose port
EXPOSE 8000

# Health check
# NOTE(review): relies on the `requests` package being listed in
# requirements.txt — confirm, or the check will fail and the container
# will be marked unhealthy. Only a connection failure fails the check;
# requests.get does not raise on a non-2xx /health response.
HEALTHCHECK --interval=30s --timeout=3s --start-period=40s --retries=3 \
    CMD python -c "import requests; requests.get('http://localhost:8000/health')"

# Run with gunicorn (production-grade WSGI server)
CMD ["gunicorn", "--bind", "0.0.0.0:8000", "--workers", "4", "--timeout", "120", "app:app"]
|
||||
67
README.md
Normal file
67
README.md
Normal file
@@ -0,0 +1,67 @@
|
||||
# SharePoint Connector for SaaS Applications
|
||||
|
||||
Enterprise-grade SharePoint integration for SaaS applications. Features secure OAuth 2.0, multi-tenant isolation, automatic text extraction, and AI-powered document chat (RAG) using local or cloud LLMs.
|
||||
|
||||
## Quick Start - Local Development
|
||||
|
||||
### 1. Install Dependencies
|
||||
```bash
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate # Windows: .\venv\Scripts\activate
|
||||
pip install -r requirements.txt
|
||||
|
||||
2. Start Required Services (Docker)
|
||||
You need local instances of DynamoDB (for configurations/sessions) and PostgreSQL with pgvector (for document embeddings).
|
||||
|
||||
# Start DynamoDB Local
|
||||
docker run -p 8000:8000 amazon/dynamodb-local
|
||||
|
||||
# Start PostgreSQL with pgvector
|
||||
docker run -d --name postgres-vector -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=sharepoint_vectors -p 5432:5432 pgvector/pgvector:pg16
|
||||
|
||||
3. Start Ollama (For AI & Embeddings)
|
||||
Ensure Ollama is installed (ollama.ai) and pull the necessary models:
|
||||
ollama pull llama3.2
|
||||
ollama pull nomic-embed-text
|
||||
|
||||
4. Configure Environment Variables
|
||||
Create a .env file in the root directory. Here are the core variables you need:
|
||||
# --- CORE APP ---
|
||||
# Use the same core variables documented in .env.example
# (SHAREPOINT_CLIENT_ID, SHAREPOINT_CLIENT_SECRET, SHAREPOINT_TENANT_ID,
#  REDIRECT_URI, ENCRYPTION_KEY, FLASK_SECRET_KEY, AWS settings).
|
||||
|
||||
# --- TOOTHFAIRYAI AGENT INTEGRATION (NEW, ALL FIELDS REQUIRED) ---
|
||||
TOOTHFAIRYAI_API_KEY=your_api_key_here
|
||||
TOOTHFAIRYAI_WORKSPACE_ID=your_workspace_uuid_here
|
||||
TOOTHFAIRYAI_API_URL=https://api.toothfairyai.com
|
||||
NGROK_URL=https://your-url.ngrok-free.app
|
||||
|
||||
5. Run the Application
|
||||
python app_dev.py
|
||||
|
||||
Open http://localhost:5001 in your browser. Enter your Azure Client ID, Client Secret, and Tenant ID in the UI to connect your SharePoint account.
|
||||
|
||||
Azure App Registration Setup
|
||||
To connect to SharePoint, you must register an app in the Azure Portal:
|
||||
|
||||
Go to Azure Active Directory -> App registrations -> New registration.
|
||||
|
||||
Set Supported account types to Multi-tenant.
|
||||
|
||||
Set Redirect URI to http://localhost:5001/sharepoint/callback (update for production).
|
||||
|
||||
Save your Application (client) ID and Directory (tenant) ID.
|
||||
|
||||
Under Certificates & secrets, create a new client secret and copy the Value.
|
||||
|
||||
Under API permissions, add Delegated permissions for Microsoft Graph: User.Read, Sites.Read.All, Files.Read.All, offline_access.
|
||||
|
||||
Production Deployment (AWS ECS)
|
||||
For production, the app is designed to run on AWS ECS Fargate.
|
||||
|
||||
Database: Use a managed AWS DynamoDB table and an RDS PostgreSQL instance.
|
||||
|
||||
Environment: Run python app.py instead of app_dev.py.
|
||||
|
||||
Security: Assign an IAM Task Role to the ECS container with strict DynamoDB permissions. Store the FLASK_SECRET_KEY in AWS Secrets Manager.
|
||||
|
||||
Routing: Place the ECS service behind an Application Load Balancer (ALB) configured with HTTPS.
|
||||
1094
app_dev.py
Normal file
1094
app_dev.py
Normal file
File diff suppressed because it is too large
Load Diff
308
background_indexer.py
Normal file
308
background_indexer.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
Background Document Indexer
|
||||
|
||||
Indexes SharePoint files in the background without blocking the UI.
|
||||
Uses threading for simple deployment (no Celery/Redis required).
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IndexingJob:
    """
    Represents a background indexing job.

    Plain mutable state holder: fields are written by the BackgroundIndexer
    worker thread and read back by status queries.
    """

    def __init__(self, job_id: str, site_id: str, site_name: str, connection_id: str, user_id: str):
        self.job_id = job_id
        self.site_id = site_id
        self.site_name = site_name
        self.connection_id = connection_id
        self.user_id = user_id
        # Lifecycle: pending -> running -> completed | failed (or cancelled
        # via BackgroundIndexer.cancel_job)
        self.status = "pending"  # pending, running, completed, failed
        self.progress = 0  # 0-100
        self.total_files = 0        # files discovered for this job
        self.processed_files = 0    # files attempted (successes + failures)
        self.successful_files = 0   # files stored in the vector store
        self.failed_files = 0       # files that raised during processing
        self.started_at = None      # set when the worker starts running
        self.completed_at = None    # set on completion/failure/cancellation
        self.error = None           # error message when status == "failed"
        self.current_file = None    # filename currently being processed
|
||||
|
||||
|
||||
class BackgroundIndexer:
    """
    Background indexer for SharePoint documents.

    Runs indexing jobs in separate threads to avoid blocking the main
    application. Job state is held in memory only, so it is lost on restart.
    """

    def __init__(self):
        # All jobs ever started in this process, keyed by job_id.
        self.jobs: Dict[str, "IndexingJob"] = {}
        # Worker threads, keyed by job_id (daemon threads, never joined).
        self.active_threads: Dict[str, threading.Thread] = {}
        # Guards self.jobs / self.active_threads across threads.
        self.lock = threading.Lock()

    def start_indexing(
        self,
        job_id: str,
        site_id: str,
        site_name: str,
        connection_id: str,
        user_id: str,
        connector,
        vector_store,
        document_parser,
        path: str = "",
        tags: Optional[List[str]] = None
    ) -> "IndexingJob":
        """
        Start a background indexing job for a SharePoint site.

        Args:
            job_id: Unique job identifier
            site_id: SharePoint site ID
            site_name: Site display name
            connection_id: SharePoint connection ID
            user_id: User ID
            connector: SharePoint connector instance
            vector_store: Vector store instance
            document_parser: Document parser instance
            path: Optional path to start indexing from
            tags: Optional tags applied to every indexed document

        Returns:
            IndexingJob instance
        """
        with self.lock:
            # Create job
            job = IndexingJob(job_id, site_id, site_name, connection_id, user_id)
            self.jobs[job_id] = job

            # Start worker thread (daemon so it never blocks process exit)
            thread = threading.Thread(
                target=self._index_site,
                args=(job, connector, vector_store, document_parser, path, tags or []),
                daemon=True
            )
            self.active_threads[job_id] = thread
            thread.start()

            logger.info(f"Started indexing job {job_id} for site {site_name}")

            return job

    def get_job_status(self, job_id: str) -> Optional[Dict]:
        """Return a JSON-serializable snapshot of a job, or None if unknown."""
        with self.lock:
            job = self.jobs.get(job_id)
            if not job:
                return None

            return {
                "job_id": job.job_id,
                "site_id": job.site_id,
                "site_name": job.site_name,
                "status": job.status,
                "progress": job.progress,
                "total_files": job.total_files,
                "processed_files": job.processed_files,
                "successful_files": job.successful_files,
                "failed_files": job.failed_files,
                "current_file": job.current_file,
                "started_at": job.started_at.isoformat() if job.started_at else None,
                "completed_at": job.completed_at.isoformat() if job.completed_at else None,
                "error": job.error
            }

    def cancel_job(self, job_id: str) -> bool:
        """
        Cancel a running indexing job.

        Cooperative: the worker thread checks job.status between files and
        stops at the next check. Returns True if the job was running.
        """
        with self.lock:
            job = self.jobs.get(job_id)
            if job and job.status == "running":
                job.status = "cancelled"
                job.completed_at = datetime.utcnow()
                logger.info(f"Cancelled indexing job {job_id}")
                return True
            return False

    def _index_site(self, job, connector, vector_store, document_parser, path="", tags=None):
        """
        Index all files in a SharePoint site (runs in background thread).

        This method:
        1. Recursively lists all files in the site
        2. Downloads and parses each file
        3. Generates embeddings
        4. Stores in vector store with specified tags
        """
        # Imported here to avoid a circular import at module load time.
        from saas_connector_dynamodb import SecureSharePointClient

        if tags is None:
            tags = []

        try:
            job.status = "running"
            job.started_at = datetime.utcnow()

            # Create SharePoint client
            client = SecureSharePointClient(connector, job.connection_id, job.user_id)

            # First, count total files
            logger.info(f"[Job {job.job_id}] Counting files in {job.site_name}...")
            all_files = self._list_all_files_recursive(client, job, job.site_id, path)

            if job.status == "cancelled":
                return

            job.total_files = len(all_files)
            logger.info(f"[Job {job.job_id}] Found {job.total_files} files to index")

            # Process each file
            for file_info in all_files:
                if job.status == "cancelled":
                    logger.info(f"[Job {job.job_id}] Job cancelled by user")
                    break

                try:
                    self._process_file(
                        job,
                        client,
                        vector_store,
                        document_parser,
                        file_info,
                        tags
                    )
                    job.successful_files += 1
                except Exception as e:
                    logger.error(f"[Job {job.job_id}] Failed to process {file_info['path']}: {e}")
                    job.failed_files += 1

                job.processed_files += 1
                job.progress = int((job.processed_files / job.total_files) * 100) if job.total_files > 0 else 0

            # Mark as completed
            if job.status != "cancelled":
                job.status = "completed"
                job.completed_at = datetime.utcnow()
                logger.info(
                    f"[Job {job.job_id}] Completed: {job.successful_files} successful, "
                    f"{job.failed_files} failed out of {job.total_files} total"
                )

        except Exception as e:
            logger.error(f"[Job {job.job_id}] Job failed with error: {e}")
            job.status = "failed"
            job.error = str(e)
            job.completed_at = datetime.utcnow()

    def _list_all_files_recursive(self, client, job, site_id, path=""):
        """Recursively list all files in a site (depth-first through folders)."""
        files_to_process = []

        try:
            items = client.list_files(site_id, path)

            for item in items:
                if job.status == "cancelled":
                    break

                # If it's a folder, recurse
                if 'folder' in item:
                    folder_name = item['name']
                    new_path = f"{path}/{folder_name}" if path else folder_name
                    files_to_process.extend(
                        self._list_all_files_recursive(client, job, site_id, new_path)
                    )
                else:
                    # It's a file
                    file_path = f"{path}/{item['name']}" if path else item['name']
                    files_to_process.append({
                        'name': item['name'],
                        'path': file_path,
                        'size': item.get('size', 0)
                    })

        except Exception as e:
            logger.error(f"[Job {job.job_id}] Error listing files in {path}: {e}")

        return files_to_process

    def _process_file(self, job, client, vector_store, document_parser, file_info, tags=None):
        """Process a single file: download, parse, embed, store."""
        if not vector_store:
            logger.warning(f"[Job {job.job_id}] Vector store not available, skipping {file_info['name']}")
            return

        if tags is None:
            tags = []

        filename = file_info['name']
        file_path = file_info['path']

        job.current_file = filename
        # FIX: these per-file log messages previously printed the literal
        # "(unknown)" instead of the filename being processed.
        logger.info(f"[Job {job.job_id}] Processing {filename}...")

        # Check if file can be parsed
        if not document_parser.can_parse(filename):
            logger.info(f"[Job {job.job_id}] Skipping unsupported file type: {filename}")
            return

        # Download file
        binary_content = client.read_file(job.site_id, file_path, as_text=False)

        # Parse content
        try:
            content = document_parser.parse(binary_content, filename)
        except Exception as parse_err:
            logger.warning(f"[Job {job.job_id}] Failed to parse {filename}: {parse_err}")
            # Try fallback to text
            try:
                content = binary_content.decode('utf-8', errors='ignore')
            except Exception:  # FIX: was a bare except (would also swallow SystemExit)
                logger.error(f"[Job {job.job_id}] Could not decode {filename}, skipping")
                return

        # Skip error messages (parsers return "[...]"-wrapped strings on failure)
        if content.startswith("[") and content.endswith("]"):
            logger.info(f"[Job {job.job_id}] Skipping error content for {filename}")
            return

        # Skip very small files (likely empty)
        if len(content.strip()) < 50:
            logger.info(f"[Job {job.job_id}] Skipping small file {filename} ({len(content)} chars)")
            return

        # Add to vector store with specified tags
        try:
            document_id = vector_store.add_document(
                user_id=job.user_id,
                site_id=job.site_id,
                file_path=file_path,
                filename=filename,
                content=content,
                tags=tags,  # Apply tags from indexing request
                chunk_size=1000,
                chunk_overlap=200
            )
            logger.info(f"[Job {job.job_id}] Indexed {filename} with document_id {document_id} and tags {tags}")
        except Exception as e:
            logger.error(f"[Job {job.job_id}] Failed to add {filename} to vector store: {e}")
            raise
|
||||
|
||||
|
||||
# Process-wide singleton indexer, created lazily on first use.
_indexer = None


def get_indexer() -> BackgroundIndexer:
    """Return the shared BackgroundIndexer, creating it on the first call."""
    global _indexer
    _indexer = _indexer or BackgroundIndexer()
    return _indexer
|
||||
196
clear_data.py
Normal file
196
clear_data.py
Normal file
@@ -0,0 +1,196 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Clear all data from the SharePoint connector application.
|
||||
This script will:
|
||||
1. Drop all vector store tables (PostgreSQL)
|
||||
2. Delete all DynamoDB tables (connections, tokens, etc.)
|
||||
3. Clear stored credentials
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import boto3
|
||||
from storage.credentials_storage import get_credentials_storage
|
||||
|
||||
def clear_postgresql_vector_store():
    """Drop the vector-store tables (tags, chunks, documents) from PostgreSQL.

    Returns True on success, False when psycopg2 is missing or any
    connection-level error occurs.
    """
    try:
        import psycopg2

        # Connection parameters come from the environment, with local-dev defaults.
        db_host = os.getenv("POSTGRES_HOST", "localhost")
        db_port = os.getenv("POSTGRES_PORT", "5432")
        db_name = os.getenv("POSTGRES_DB", "sharepoint_vectors")
        db_user = os.getenv("POSTGRES_USER", "postgres")
        db_password = os.getenv("POSTGRES_PASSWORD", "postgres")

        print(f"📊 Connecting to PostgreSQL at {db_host}:{db_port}/{db_name}...")

        connection = psycopg2.connect(
            host=db_host,
            port=db_port,
            database=db_name,
            user=db_user,
            password=db_password,
        )
        cursor = connection.cursor()

        table_prefix = os.getenv("TABLE_PREFIX", "dev_")

        # Child tables first so foreign keys don't block the drops.
        for table in (f"{table_prefix}document_tags",
                      f"{table_prefix}document_chunks",
                      f"{table_prefix}documents"):
            try:
                cursor.execute(f"DROP TABLE IF EXISTS {table} CASCADE")
                print(f" ✅ Dropped table: {table}")
            except Exception as e:
                print(f" ⚠️ Could not drop {table}: {e}")

        connection.commit()
        cursor.close()
        connection.close()

        print("✅ PostgreSQL vector store cleared")
        return True

    except ImportError:
        print("⚠️ psycopg2 not installed, skipping PostgreSQL cleanup")
        return False
    except Exception as e:
        print(f"❌ Error clearing PostgreSQL: {e}")
        return False
|
||||
|
||||
|
||||
def clear_dynamodb_tables():
    """Delete every DynamoDB table whose name starts with the configured prefix.

    Works against either dynamodb-local (when the endpoint points at
    localhost) or the real AWS service. Returns True on success.
    """
    try:
        endpoint_url = os.getenv("DYNAMODB_ENDPOINT", "http://localhost:8000")
        region = os.getenv("AWS_REGION", "ap-southeast-2")
        table_prefix = os.getenv("TABLE_PREFIX", "dev_")

        # A localhost endpoint means the local emulator, not AWS.
        is_local = "localhost" in endpoint_url or "127.0.0.1" in endpoint_url
        if is_local:
            print(f"📊 Connecting to LOCAL DynamoDB at {endpoint_url}...")
        else:
            print(f"📊 Connecting to AWS DynamoDB in {region}...")

        client = boto3.client(
            'dynamodb',
            region_name=region,
            endpoint_url=endpoint_url if is_local else None
        )

        # Only touch tables that carry our prefix.
        matching = [t for t in client.list_tables().get('TableNames', [])
                    if t.startswith(table_prefix)]

        if not matching:
            print(f" ℹ️ No tables found with prefix '{table_prefix}'")
            return True

        print(f" Found {len(matching)} tables to delete:")
        for table in matching:
            print(f" - {table}")

        for table in matching:
            try:
                client.delete_table(TableName=table)
                print(f" ✅ Deleted table: {table}")
            except Exception as e:
                print(f" ⚠️ Could not delete {table}: {e}")

        print("✅ DynamoDB tables cleared")
        return True

    except Exception as e:
        print(f"❌ Error clearing DynamoDB: {e}")
        return False
|
||||
|
||||
|
||||
def clear_credentials():
    """Remove the stored-credentials file, if the storage backend keeps one.

    Returns True on success (including "nothing to delete"), False on error.
    """
    try:
        print("📊 Clearing stored credentials...")

        storage = get_credentials_storage()

        # Only file-backed storage implementations expose `file_path`.
        if not hasattr(storage, 'file_path'):
            print(" ℹ️ Credentials storage does not use file system")
        else:
            file_path = storage.file_path
            if os.path.exists(file_path):
                os.remove(file_path)
                print(f" ✅ Deleted credentials file: {file_path}")
            else:
                print(f" ℹ️ No credentials file found at: {file_path}")

        print("✅ Credentials cleared")
        return True

    except Exception as e:
        print(f"❌ Error clearing credentials: {e}")
        return False
|
||||
|
||||
|
||||
def main():
    """Wipe all connector data, after an interactive (or --force) confirmation."""
    banner = "=" * 60
    print(banner)
    print("🧹 CLEARING ALL SHAREPOINT CONNECTOR DATA")
    print(banner)
    print()

    # --force / -f skips the interactive confirmation (useful in automation).
    if not ('--force' in sys.argv or '-f' in sys.argv):
        try:
            answer = input("⚠️ This will DELETE ALL data. Are you sure? (yes/no): ")
            if answer.lower() not in ['yes', 'y']:
                print("❌ Aborted")
                return
        except EOFError:
            print("\n❌ No input received. Use --force to skip confirmation.")
            return

    print()
    print("Starting cleanup...")
    print()

    # Each step reports its own progress and swallows its own errors.
    cleanup_steps = (
        ("1️⃣ Clearing PostgreSQL Vector Store...", clear_postgresql_vector_store),
        ("2️⃣ Clearing DynamoDB Tables...", clear_dynamodb_tables),
        ("3️⃣ Clearing Stored Credentials...", clear_credentials),
    )
    for label, step in cleanup_steps:
        print(label)
        step()
        print()

    print(banner)
    print("✅ ALL DATA CLEARED SUCCESSFULLY")
    print(banner)
    print()
    print("You can now start fresh by running the application:")
    print(" python app_dev.py")
    print()


if __name__ == "__main__":
    main()
|
||||
54
deploy.sh
Normal file
54
deploy.sh
Normal file
@@ -0,0 +1,54 @@
|
||||
#!/bin/bash

# SharePoint Connector - AWS ECS Deployment Script
#
# Builds the Docker image, pushes it to ECR, registers an updated task
# definition, and forces a new deployment of the ECS service.
# Fixes: variable expansions are now quoted (safe against spaces/globbing)
# and the useless `cat | sed` pipeline is gone.

set -e

# Configuration — adjust ECS_CLUSTER / ECS_SERVICE for your environment.
AWS_REGION="ap-southeast-2"
AWS_ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text)
ECR_REPOSITORY="sharepoint-connector"
ECS_CLUSTER="your-cluster-name"
ECS_SERVICE="sharepoint-connector-service"
TASK_FAMILY="sharepoint-connector"

echo "🚀 Deploying SharePoint Connector to AWS ECS"
echo "📍 Region: $AWS_REGION"
echo "🏷️ Account: $AWS_ACCOUNT_ID"
echo ""

# Step 1: Build Docker image
echo "🔨 Building Docker image..."
docker build -t "$ECR_REPOSITORY:latest" .

# Step 2: Login to ECR
echo "🔐 Logging in to ECR..."
aws ecr get-login-password --region "$AWS_REGION" | docker login --username AWS --password-stdin "$AWS_ACCOUNT_ID.dkr.ecr.$AWS_REGION.amazonaws.com"

# Step 3: Create ECR repository if it doesn't exist
echo "📦 Ensuring ECR repository exists..."
aws ecr describe-repositories --repository-names "$ECR_REPOSITORY" --region "$AWS_REGION" 2>/dev/null || \
    aws ecr create-repository --repository-name "$ECR_REPOSITORY" --region "$AWS_REGION"

# Step 4: Tag and push image
echo "📤 Pushing image to ECR..."
IMAGE_URI="$AWS_ACCOUNT_ID.dkr.ecr.$AWS_REGION.amazonaws.com/$ECR_REPOSITORY:latest"
docker tag "$ECR_REPOSITORY:latest" "$IMAGE_URI"
docker push "$IMAGE_URI"

# Step 5: Update task definition (substitute the real account id)
echo "📝 Updating ECS task definition..."
TASK_DEFINITION=$(sed "s/YOUR_ACCOUNT_ID/$AWS_ACCOUNT_ID/g" task-definition.json)
aws ecs register-task-definition --cli-input-json "$TASK_DEFINITION" --region "$AWS_REGION"

# Step 6: Update ECS service
echo "🔄 Updating ECS service..."
aws ecs update-service \
    --cluster "$ECS_CLUSTER" \
    --service "$ECS_SERVICE" \
    --task-definition "$TASK_FAMILY" \
    --force-new-deployment \
    --region "$AWS_REGION"

echo ""
echo "✅ Deployment complete!"
echo "📊 Check status: aws ecs describe-services --cluster $ECS_CLUSTER --services $ECS_SERVICE --region $AWS_REGION"
|
||||
30
docker-compose.yml
Normal file
30
docker-compose.yml
Normal file
@@ -0,0 +1,30 @@
|
||||
version: '3.8'

services:
  # DynamoDB Local for development
  dynamodb-local:
    image: amazon/dynamodb-local
    container_name: dynamodb-local
    ports:
      - "8000:8000"
    # -sharedDb: one database regardless of credentials; -inMemory: data is
    # lost when the container stops
    command: "-jar DynamoDBLocal.jar -sharedDb -inMemory"

  # Development app
  app-dev:
    build:
      context: .
      dockerfile: Dockerfile
    container_name: sharepoint-connector-dev
    # NOTE(review): maps host 5000 -> container 5000, but the Dockerfile
    # EXPOSEs 8000 and its gunicorn CMD binds 0.0.0.0:8000 — confirm which
    # port the dev entrypoint actually listens on.
    ports:
      - "5000:5000"
    environment:
      # Dummy credentials — dynamodb-local accepts any key pair
      - AWS_REGION=ap-southeast-2
      - AWS_ACCESS_KEY_ID=fakeAccessKeyId
      - AWS_SECRET_ACCESS_KEY=fakeSecretAccessKey
    volumes:
      # Mount the dev entrypoint over app.py so `python app.py` runs app_dev.py
      - ./app_dev.py:/app/app.py
      - ./templates:/app/templates
      - ./static:/app/static
    depends_on:
      - dynamodb-local
    # Overrides the image's gunicorn CMD with the plain Python entrypoint
    command: python app.py
|
||||
245
document_parser.py
Normal file
245
document_parser.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
Document Parser - Extract text from various file formats
|
||||
|
||||
Supports: PDF, Word (.docx), Excel (.xlsx), PowerPoint (.pptx),
|
||||
CSV, text files, and more
|
||||
"""
|
||||
|
||||
import io
|
||||
import mimetypes
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class DocumentParser:
|
||||
"""Parse different document types and extract text content for LLM processing."""
|
||||
|
||||
    def __init__(self):
        """Initialize document parser with the set of parseable extensions."""
        # Lowercase extensions (leading dot included) that parse() accepts.
        self.supported_extensions = {
            # Text formats
            '.txt', '.md', '.csv', '.json', '.xml', '.html', '.css', '.js',
            '.py', '.java', '.cpp', '.c', '.h', '.sh', '.yaml', '.yml',
            # Microsoft Office
            '.docx', '.xlsx', '.pptx',
            # PDF
            '.pdf',
            # Other
            '.rtf', '.log'
        }
|
||||
|
||||
def can_parse(self, filename: str) -> bool:
|
||||
"""Check if file can be parsed."""
|
||||
ext = self._get_extension(filename)
|
||||
return ext in self.supported_extensions
|
||||
|
||||
def parse(self, content: bytes, filename: str) -> str:
|
||||
"""
|
||||
Parse document and extract text content.
|
||||
|
||||
Args:
|
||||
content: File content as bytes
|
||||
filename: Original filename (used to determine file type)
|
||||
|
||||
Returns:
|
||||
Extracted text content
|
||||
|
||||
Raises:
|
||||
ValueError: If file type is not supported
|
||||
"""
|
||||
ext = self._get_extension(filename)
|
||||
|
||||
if ext not in self.supported_extensions:
|
||||
raise ValueError(f"Unsupported file type: {ext}")
|
||||
|
||||
# Text files - direct decode
|
||||
if ext in {'.txt', '.md', '.csv', '.json', '.xml', '.html', '.css',
|
||||
'.js', '.py', '.java', '.cpp', '.c', '.h', '.sh', '.yaml',
|
||||
'.yml', '.log', '.rtf'}:
|
||||
return self._parse_text(content)
|
||||
|
||||
# PDF
|
||||
elif ext == '.pdf':
|
||||
return self._parse_pdf(content)
|
||||
|
||||
# Microsoft Word
|
||||
elif ext == '.docx':
|
||||
return self._parse_docx(content)
|
||||
|
||||
# Microsoft Excel
|
||||
elif ext == '.xlsx':
|
||||
return self._parse_xlsx(content)
|
||||
|
||||
# Microsoft PowerPoint
|
||||
elif ext == '.pptx':
|
||||
return self._parse_pptx(content)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Parser not implemented for: {ext}")
|
||||
|
||||
def _get_extension(self, filename: str) -> str:
|
||||
"""Get file extension in lowercase."""
|
||||
return '.' + filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''
|
||||
|
||||
def _parse_text(self, content: bytes) -> str:
|
||||
"""Parse plain text files."""
|
||||
# Try multiple encodings
|
||||
for encoding in ['utf-8', 'latin-1', 'cp1252', 'ascii']:
|
||||
try:
|
||||
return content.decode(encoding)
|
||||
except (UnicodeDecodeError, AttributeError):
|
||||
continue
|
||||
|
||||
# Fallback: decode with errors='ignore'
|
||||
return content.decode('utf-8', errors='ignore')
|
||||
|
||||
def _parse_pdf(self, content: bytes) -> str:
    """Extract text from a PDF, one '--- Page N ---' section per non-empty page."""
    try:
        import PyPDF2

        reader = PyPDF2.PdfReader(io.BytesIO(content))

        pages = []
        for index, page in enumerate(reader.pages):
            extracted = page.extract_text()
            if extracted.strip():
                pages.append(f"--- Page {index + 1} ---\n{extracted}")

        if not pages:
            return "(empty PDF)"
        return '\n\n'.join(pages)

    except ImportError:
        return "[PDF parsing requires PyPDF2: pip install PyPDF2]"
    except Exception as e:
        # Best-effort parser: report failures inline rather than raising.
        return f"[Error parsing PDF: {str(e)}]"
|
||||
|
||||
def _parse_docx(self, content: bytes) -> str:
    """Extract text from a .docx file: body paragraphs first, then tables."""
    try:
        import docx

        document = docx.Document(io.BytesIO(content))

        parts = []

        # Body paragraphs (blank ones are skipped).
        parts.extend(p.text for p in document.paragraphs if p.text.strip())

        # Tables: one line per row, cells separated by ' | '.
        for table in document.tables:
            rows = [
                ' | '.join(cell.text.strip() for cell in row.cells)
                for row in table.rows
            ]
            if rows:
                parts.append('\n' + '\n'.join(rows))

        return '\n\n'.join(parts) if parts else "(empty document)"

    except ImportError:
        return "[Word parsing requires python-docx: pip install python-docx]"
    except Exception as e:
        # Best-effort parser: report failures inline rather than raising.
        return f"[Error parsing Word document: {str(e)}]"
|
||||
|
||||
def _parse_xlsx(self, content: bytes) -> str:
    """Extract cell values from an .xlsx workbook, one section per sheet."""
    try:
        import openpyxl

        # data_only=True returns cached formula results instead of formulas.
        workbook = openpyxl.load_workbook(io.BytesIO(content), data_only=True)

        sections = []
        for name in workbook.sheetnames:
            worksheet = workbook[name]

            lines = [f"=== Sheet: {name} ==="]
            for row in worksheet.iter_rows(values_only=True):
                # Ignore rows where every cell is empty.
                if all(cell is None for cell in row):
                    continue
                lines.append(
                    ' | '.join('' if cell is None else str(cell) for cell in row)
                )

            # Keep only sheets that produced data beyond the header line.
            if len(lines) > 1:
                sections.append('\n'.join(lines))

        return '\n\n'.join(sections) if sections else "(empty spreadsheet)"

    except ImportError:
        return "[Excel parsing requires openpyxl: pip install openpyxl]"
    except Exception as e:
        # Best-effort parser: report failures inline rather than raising.
        return f"[Error parsing Excel file: {str(e)}]"
|
||||
|
||||
def _parse_pptx(self, content: bytes) -> str:
    """Extract shape text from a .pptx deck, one '=== Slide N ===' section per slide."""
    try:
        from pptx import Presentation

        deck = Presentation(io.BytesIO(content))

        sections = []
        for number, slide in enumerate(deck.slides, start=1):
            lines = [f"=== Slide {number} ==="]
            # Not every shape carries text (pictures, charts), hence hasattr.
            lines.extend(
                shape.text
                for shape in slide.shapes
                if hasattr(shape, "text") and shape.text.strip()
            )
            # Keep only slides with text beyond the header line.
            if len(lines) > 1:
                sections.append('\n'.join(lines))

        return '\n\n'.join(sections) if sections else "(empty presentation)"

    except ImportError:
        return "[PowerPoint parsing requires python-pptx: pip install python-pptx]"
    except Exception as e:
        # Best-effort parser: report failures inline rather than raising.
        return f"[Error parsing PowerPoint file: {str(e)}]"
|
||||
|
||||
|
||||
def get_file_info(filename: str, file_size: int) -> dict:
    """
    Describe a file in human-readable terms.

    Args:
        filename: File name; its extension determines the category.
        file_size: Size in bytes.

    Returns:
        Dict with 'extension', 'category', 'size_formatted', 'size_bytes'.
    """
    ext = '.' + filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''

    # Extension buckets; first matching bucket wins, default is 'file'.
    type_categories = {
        'document': {'.pdf', '.docx', '.doc', '.txt', '.rtf', '.odt'},
        'spreadsheet': {'.xlsx', '.xls', '.csv', '.ods'},
        'presentation': {'.pptx', '.ppt', '.odp'},
        'image': {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.svg', '.webp'},
        'video': {'.mp4', '.avi', '.mov', '.wmv', '.flv', '.mkv', '.webm'},
        'audio': {'.mp3', '.wav', '.flac', '.aac', '.ogg', '.wma'},
        'archive': {'.zip', '.rar', '.7z', '.tar', '.gz'},
        'code': {'.py', '.js', '.java', '.cpp', '.c', '.h', '.cs', '.php', '.rb'},
        'web': {'.html', '.css', '.xml', '.json', '.yaml', '.yml'},
    }

    category = next(
        (cat for cat, extensions in type_categories.items() if ext in extensions),
        'file',
    )

    # Size rendering: bytes are exact; KB/MB/GB carry one decimal place.
    kb, mb, gb = 1024, 1024 ** 2, 1024 ** 3
    if file_size < kb:
        size_str = f"{file_size} B"
    elif file_size < mb:
        size_str = f"{file_size / kb:.1f} KB"
    elif file_size < gb:
        size_str = f"{file_size / mb:.1f} MB"
    else:
        size_str = f"{file_size / gb:.1f} GB"

    return {
        'extension': ext,
        'category': category,
        'size_formatted': size_str,
        'size_bytes': file_size,
    }
|
||||
23
iam-task-role-policy.json
Normal file
23
iam-task-role-policy.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Action": [
|
||||
"dynamodb:CreateTable",
|
||||
"dynamodb:DescribeTable",
|
||||
"dynamodb:GetItem",
|
||||
"dynamodb:PutItem",
|
||||
"dynamodb:UpdateItem",
|
||||
"dynamodb:Query",
|
||||
"dynamodb:Scan",
|
||||
"dynamodb:UpdateTimeToLive",
|
||||
"dynamodb:DescribeTimeToLive"
|
||||
],
|
||||
"Resource": [
|
||||
"arn:aws:dynamodb:ap-southeast-2:*:table/prod_*",
|
||||
"arn:aws:dynamodb:ap-southeast-2:*:table/prod_*/index/*"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
324
llm_client.py
Normal file
324
llm_client.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""
|
||||
LLM Client Abstraction Layer
|
||||
Supports multiple LLM providers with easy swapping
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Optional
|
||||
import requests
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
class LLMClient(ABC):
    """Common interface implemented by every LLM provider client."""

    @abstractmethod
    def chat(self, messages: List[Dict[str, str]], context: Optional[str] = None) -> str:
        """
        Send a conversation to the LLM and return its reply.

        Args:
            messages: Conversation turns as {"role": "user"/"assistant",
                "content": "text"} dicts.
            context: Optional document text for the model to ground
                its answer in.

        Returns:
            The model's response text.
        """
        ...

    @abstractmethod
    def is_available(self) -> bool:
        """Return True when the backing LLM service can be reached."""
        ...
|
||||
|
||||
|
||||
class OllamaClient(LLMClient):
    """Ollama LLM client (local deployment)."""

    def __init__(self,
                 base_url: str = "http://localhost:11434",
                 model: str = "llama3.2",
                 timeout: int = 120):
        """
        Initialize Ollama client.

        Args:
            base_url: Ollama server URL (any trailing slash is stripped).
            model: Model name (e.g., llama3.2, mistral, codellama).
            timeout: Request timeout in seconds.
        """
        self.base_url = base_url.rstrip('/')
        self.model = model
        self.timeout = timeout

    def _build_messages(self, messages: List[Dict[str, str]],
                        context: Optional[str]) -> List[Dict[str, str]]:
        """
        Prepend a system message carrying document context, if any.

        Extracted helper: chat() and chat_stream() previously duplicated
        this assembly logic verbatim.
        """
        formatted_messages = []
        if context:
            formatted_messages.append({
                "role": "system",
                "content": f"You are a helpful assistant. Use the following document content to answer questions:\n\n{context}"
            })
        formatted_messages.extend(messages)
        return formatted_messages

    def chat(self, messages: List[Dict[str, str]], context: Optional[str] = None) -> str:
        """
        Send a chat request to Ollama and return the complete reply.

        Args:
            messages: Conversation turns as role/content dicts.
            context: Optional document text injected as a system message.

        Returns:
            The model's response text.

        Raises:
            Exception: If the HTTP request to Ollama fails.
        """
        formatted_messages = self._build_messages(messages, context)

        try:
            response = requests.post(
                f"{self.base_url}/api/chat",
                json={
                    "model": self.model,
                    "messages": formatted_messages,
                    "stream": False
                },
                timeout=self.timeout
            )
            response.raise_for_status()

            result = response.json()
            return result["message"]["content"]

        except requests.exceptions.RequestException as e:
            raise Exception(f"Ollama API error: {str(e)}")

    def chat_stream(self, messages: List[Dict[str, str]], context: Optional[str] = None):
        """
        Send a chat request to Ollama, yielding reply fragments as they arrive.

        Args:
            messages: Conversation turns as role/content dicts.
            context: Optional document text injected as a system message.

        Yields:
            str: Incremental content fragments.

        Raises:
            Exception: If the HTTP request to Ollama fails.
        """
        formatted_messages = self._build_messages(messages, context)

        try:
            response = requests.post(
                f"{self.base_url}/api/chat",
                json={
                    "model": self.model,
                    "messages": formatted_messages,
                    "stream": True
                },
                timeout=self.timeout,
                stream=True
            )
            response.raise_for_status()

            # Each streamed line is a standalone JSON object; malformed
            # lines are skipped rather than aborting the stream.
            for line in response.iter_lines():
                if line:
                    try:
                        chunk = json.loads(line)
                        if "message" in chunk and "content" in chunk["message"]:
                            yield chunk["message"]["content"]
                        if chunk.get("done", False):
                            break
                    except json.JSONDecodeError:
                        continue

        except requests.exceptions.RequestException as e:
            raise Exception(f"Ollama API error: {str(e)}")

    def is_available(self) -> bool:
        """Return True when the Ollama server answers /api/tags."""
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
            return response.status_code == 200
        except requests.exceptions.RequestException:
            # Fix: the original bare `except:` also swallowed SystemExit and
            # KeyboardInterrupt; only network failures mean "unavailable".
            return False
|
||||
|
||||
|
||||
class OpenAIClient(LLMClient):
    """OpenAI API client (for easy swapping)."""

    def __init__(self,
                 api_key: str,
                 model: str = "gpt-4",
                 timeout: int = 120):
        """
        Initialize OpenAI client.

        Args:
            api_key: OpenAI API key.
            model: Model name (e.g., gpt-4, gpt-3.5-turbo).
            timeout: Request timeout in seconds.
        """
        self.api_key = api_key
        self.model = model
        self.timeout = timeout

    def chat(self, messages: List[Dict[str, str]], context: Optional[str] = None) -> str:
        """
        Send a conversation to the OpenAI chat completions API.

        Args:
            messages: Conversation turns as role/content dicts.
            context: Optional document text injected as a system message.

        Returns:
            The assistant's reply text.

        Raises:
            Exception: If the HTTP request fails or returns an error status.
        """
        formatted_messages = []

        if context:
            formatted_messages.append({
                "role": "system",
                "content": f"You are a helpful assistant. Use the following document content to answer questions:\n\n{context}"
            })

        formatted_messages.extend(messages)

        try:
            response = requests.post(
                "https://api.openai.com/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": self.model,
                    "messages": formatted_messages
                },
                timeout=self.timeout
            )
            response.raise_for_status()

            result = response.json()
            return result["choices"][0]["message"]["content"]

        except requests.exceptions.RequestException as e:
            raise Exception(f"OpenAI API error: {str(e)}")

    def is_available(self) -> bool:
        """Return True when the API key is accepted by the /v1/models endpoint."""
        try:
            response = requests.get(
                "https://api.openai.com/v1/models",
                headers={"Authorization": f"Bearer {self.api_key}"},
                timeout=5
            )
            return response.status_code == 200
        except requests.exceptions.RequestException:
            # Fix: the original bare `except:` also swallowed SystemExit and
            # KeyboardInterrupt; only network failures mean "unavailable".
            return False
|
||||
|
||||
|
||||
class AnthropicClient(LLMClient):
    """Anthropic Claude API client (for easy swapping)."""

    def __init__(self,
                 api_key: str,
                 model: str = "claude-3-5-sonnet-20241022",
                 timeout: int = 120):
        """
        Initialize Anthropic client.

        Args:
            api_key: Anthropic API key.
            model: Model name.
            timeout: Request timeout in seconds.
        """
        self.api_key = api_key
        self.model = model
        self.timeout = timeout

    def chat(self, messages: List[Dict[str, str]], context: Optional[str] = None) -> str:
        """
        Send a conversation to the Anthropic Messages API.

        Unlike the OpenAI/Ollama clients, document context travels in the
        request's top-level "system" field rather than as a system message.

        Args:
            messages: Conversation turns as role/content dicts.
            context: Optional document text passed via the "system" field.

        Returns:
            The assistant's reply text.

        Raises:
            Exception: If the HTTP request fails or returns an error status.
        """
        conversation = list(messages)

        try:
            payload = {
                "model": self.model,
                "messages": conversation,
                "max_tokens": 4096
            }
            if context:
                payload["system"] = (
                    "You are a helpful assistant. Use the following document "
                    f"content to answer questions:\n\n{context}"
                )

            response = requests.post(
                "https://api.anthropic.com/v1/messages",
                headers={
                    "x-api-key": self.api_key,
                    "anthropic-version": "2023-06-01",
                    "Content-Type": "application/json"
                },
                json=payload,
                timeout=self.timeout
            )
            response.raise_for_status()

            return response.json()["content"][0]["text"]

        except requests.exceptions.RequestException as e:
            raise Exception(f"Anthropic API error: {str(e)}")

    def is_available(self) -> bool:
        """Assume availability whenever a non-empty API key is configured."""
        # Anthropic doesn't have a simple health check, so we assume it's available
        return len(self.api_key) > 0
|
||||
|
||||
|
||||
def create_llm_client(provider: str = "ollama", **kwargs) -> LLMClient:
    """
    Factory function to create an LLM client.

    Args:
        provider: LLM provider name ("ollama", "openai", "anthropic");
            matched case-insensitively.
        **kwargs: Provider-specific configuration (api_key, model, base_url,
            timeout). Missing values fall back to environment variables.

    Returns:
        LLMClient instance.

    Raises:
        ValueError: For an unknown provider or a missing required API key.

    Example:
        # Ollama (default)
        client = create_llm_client("ollama", model="llama3.2")

        # OpenAI
        client = create_llm_client("openai", api_key="sk-...", model="gpt-4")

        # Anthropic
        client = create_llm_client("anthropic", api_key="sk-ant-...", model="claude-3-5-sonnet-20241022")
    """
    name = provider.lower()
    timeout = kwargs.get("timeout", 120)

    if name == "ollama":
        return OllamaClient(
            base_url=kwargs.get("base_url", os.getenv("OLLAMA_URL", "http://localhost:11434")),
            model=kwargs.get("model", os.getenv("OLLAMA_MODEL", "llama3.2")),
            timeout=timeout,
        )

    if name == "openai":
        api_key = kwargs.get("api_key", os.getenv("OPENAI_API_KEY"))
        if not api_key:
            raise ValueError("OpenAI API key required")
        return OpenAIClient(
            api_key=api_key,
            model=kwargs.get("model", "gpt-4"),
            timeout=timeout,
        )

    if name == "anthropic":
        api_key = kwargs.get("api_key", os.getenv("ANTHROPIC_API_KEY"))
        if not api_key:
            raise ValueError("Anthropic API key required")
        return AnthropicClient(
            api_key=api_key,
            model=kwargs.get("model", "claude-3-5-sonnet-20241022"),
            timeout=timeout,
        )

    raise ValueError(f"Unknown LLM provider: {provider}")
|
||||
23
requirements.txt
Normal file
23
requirements.txt
Normal file
@@ -0,0 +1,23 @@
|
||||
# SharePoint Connector Requirements
|
||||
|
||||
# Core dependencies
|
||||
boto3>=1.34.0
|
||||
cryptography>=41.0.0
|
||||
msal>=1.24.0
|
||||
requests>=2.31.0
|
||||
|
||||
# Web framework
|
||||
flask>=3.0.0
|
||||
|
||||
# Production server
|
||||
gunicorn>=21.2.0
|
||||
|
||||
# Document parsing
|
||||
PyPDF2>=3.0.0
|
||||
python-docx>=1.0.0
|
||||
openpyxl>=3.1.0
|
||||
python-pptx>=0.6.0
|
||||
|
||||
# PostgreSQL with pgvector for embeddings
|
||||
psycopg2-binary>=2.9.0
|
||||
pgvector>=0.2.0
|
||||
918
saas_connector_dynamodb.py
Normal file
918
saas_connector_dynamodb.py
Normal file
@@ -0,0 +1,918 @@
|
||||
"""
|
||||
Enterprise SaaS SharePoint Connector - DynamoDB Edition
|
||||
|
||||
Secure multi-tenant SharePoint connector using AWS DynamoDB for storage.
|
||||
Implements OAuth 2.0 with encrypted token storage, DynamoDB persistence,
|
||||
and enterprise security best practices.
|
||||
"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
import hashlib
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
from decimal import Decimal
|
||||
|
||||
import requests
|
||||
from msal import ConfidentialClientApplication
|
||||
from cryptography.fernet import Fernet
|
||||
import boto3
|
||||
from boto3.dynamodb.conditions import Key, Attr
|
||||
from botocore.exceptions import ClientError
|
||||
from flask import Flask, request, redirect, session, jsonify, url_for
|
||||
from functools import wraps
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DYNAMODB HELPER FUNCTIONS
|
||||
# ============================================================================
|
||||
|
||||
def decimal_to_float(obj):
    """Recursively replace DynamoDB Decimal values with plain int/float."""
    if isinstance(obj, Decimal):
        # Whole-number Decimals become ints, fractional ones become floats.
        return int(obj) if obj % 1 == 0 else float(obj)
    if isinstance(obj, dict):
        return {key: decimal_to_float(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [decimal_to_float(item) for item in obj]
    # Anything else is already JSON-serializable as-is.
    return obj
|
||||
|
||||
|
||||
def datetime_to_iso(dt: Optional[datetime]) -> Optional[str]:
    """
    Convert a datetime to an ISO 8601 string; None passes through.

    The original annotation claimed ``-> str`` although the function
    returns None for falsy input, so the signature is widened to
    Optional on both sides (annotation-only change; behavior identical).
    """
    return dt.isoformat() if dt else None
|
||||
|
||||
|
||||
def iso_to_datetime(iso_string: str) -> Optional[datetime]:
    """Parse an ISO 8601 string; return None for empty or malformed input."""
    if not iso_string:
        return None
    try:
        return datetime.fromisoformat(iso_string)
    except (ValueError, AttributeError):
        # Malformed strings map to None rather than raising.
        return None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ENCRYPTION SERVICE
|
||||
# ============================================================================
|
||||
|
||||
class TokenEncryption:
    """
    Encrypts and decrypts OAuth tokens with Fernet symmetric encryption.

    SECURITY: Never store tokens in plain text!
    """

    def __init__(self, encryption_key: Optional[str] = None):
        """
        Initialize encryption service.

        Args:
            encryption_key: Base64-encoded Fernet key. Falls back to the
                ENCRYPTION_KEY environment variable when omitted.

        Raises:
            ValueError: If no key is supplied and ENCRYPTION_KEY is unset.
        """
        key_source = encryption_key or os.getenv('ENCRYPTION_KEY')
        if not key_source:
            raise ValueError("ENCRYPTION_KEY environment variable must be set")

        self.key = key_source.encode()
        self.cipher = Fernet(self.key)

    @staticmethod
    def generate_key() -> str:
        """Generate a new encryption key. Store this securely!"""
        return Fernet.generate_key().decode()

    def encrypt(self, plaintext: str) -> str:
        """Encrypt a string; empty input maps to an empty string."""
        return self.cipher.encrypt(plaintext.encode()).decode() if plaintext else ""

    def decrypt(self, ciphertext: str) -> str:
        """Decrypt a string; empty input maps to an empty string."""
        return self.cipher.decrypt(ciphertext.encode()).decode() if ciphertext else ""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DYNAMODB DATA MODELS
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
class SharePointConnectionInfo:
    """Connection information for a SharePoint connection."""
    id: str                           # Unique connection identifier
    user_id: str                      # Owning SaaS user ID
    organization_id: Optional[str]    # Owning SaaS organization, if any
    connection_name: Optional[str]    # Human-readable display name
    tenant_id: str                    # Azure AD tenant the connection belongs to
    microsoft_user_id: str            # Microsoft account ID behind the OAuth grant
    is_active: bool                   # False once the connection is deactivated
    created_at: datetime              # Creation time (naive UTC)
    last_used_at: Optional[datetime]  # Last use time (naive UTC), or None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DYNAMODB CONNECTOR
|
||||
# ============================================================================
|
||||
|
||||
class DynamoDBSharePointConnector:
|
||||
"""
|
||||
Enterprise-grade SharePoint connector using DynamoDB for storage.
|
||||
|
||||
DynamoDB Tables:
|
||||
1. sharepoint_connections - Stores encrypted OAuth tokens
|
||||
2. sharepoint_oauth_states - Temporary OAuth state for CSRF protection
|
||||
3. sharepoint_audit_logs - Audit trail for compliance
|
||||
|
||||
Features:
|
||||
- Multi-tenant OAuth 2.0
|
||||
- Encrypted token storage
|
||||
- Automatic token refresh
|
||||
- Audit logging
|
||||
- CSRF protection
|
||||
- Serverless-friendly (no connection pooling)
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    client_id: str,
    client_secret: str,
    redirect_uri: str,
    encryption_key: str,
    aws_region: str = "us-east-1",
    dynamodb_endpoint: Optional[str] = None,  # For local development
    table_prefix: str = "",
    tenant_id: str = "common"
):
    """
    Initialize secure DynamoDB connector.

    Side effect: creates the three DynamoDB tables on first run (via
    _ensure_tables_exist), so the caller needs table-creation permissions.

    Args:
        client_id: Azure AD application ID
        client_secret: Azure AD client secret
        redirect_uri: OAuth callback URL (must be HTTPS in production!)
        encryption_key: Fernet encryption key for token storage
        aws_region: AWS region for DynamoDB
        dynamodb_endpoint: Custom endpoint (e.g., http://localhost:8000 for local)
        table_prefix: Prefix for table names (e.g., "prod_" or "dev_")
        tenant_id: Azure AD tenant or "common" for multi-tenant
    """
    self.client_id = client_id
    self.client_secret = client_secret
    self.redirect_uri = redirect_uri
    self.tenant_id = tenant_id
    self.table_prefix = table_prefix

    # Required scopes for SharePoint access
    # Note: offline_access is handled automatically by MSAL and should NOT be included
    self.scopes = [
        "https://graph.microsoft.com/Sites.Read.All",
        "https://graph.microsoft.com/Files.Read.All",
        "https://graph.microsoft.com/User.Read"
    ]

    # Initialize encryption (raises ValueError if no usable key).
    self.encryption = TokenEncryption(encryption_key)

    # Initialize DynamoDB
    if dynamodb_endpoint:
        # Local development (e.g., DynamoDB Local on localhost:8000)
        self.dynamodb = boto3.resource('dynamodb',
                                       region_name=aws_region,
                                       endpoint_url=dynamodb_endpoint)
    else:
        # Production: credentials/region resolved by boto3's default chain
        self.dynamodb = boto3.resource('dynamodb', region_name=aws_region)

    # Table names (prefix separates environments sharing one account)
    self.connections_table_name = f"{table_prefix}sharepoint_connections"
    self.oauth_states_table_name = f"{table_prefix}sharepoint_oauth_states"
    self.audit_logs_table_name = f"{table_prefix}sharepoint_audit_logs"

    # Initialize tables (creates them if missing)
    self._ensure_tables_exist()

    # Initialize MSAL confidential client for the OAuth code flow
    self.authority = f"https://login.microsoftonline.com/{tenant_id}"
    self.msal_app = ConfidentialClientApplication(
        client_id=self.client_id,
        client_credential=self.client_secret,
        authority=self.authority
    )
|
||||
|
||||
def _ensure_tables_exist(self):
    """
    Bind the three DynamoDB tables, creating any that do not exist.

    Fix: the original caught ANY ClientError from table.load() and
    responded by calling create_table — so a permissions or throttling
    error was silently treated as "table missing" and surfaced as a
    confusing creation failure. Only ResourceNotFoundException now
    triggers creation; every other ClientError is re-raised.
    """
    self.connections_table = self._load_or_create(
        self.connections_table_name, self._create_connections_table)
    self.oauth_states_table = self._load_or_create(
        self.oauth_states_table_name, self._create_oauth_states_table)
    self.audit_logs_table = self._load_or_create(
        self.audit_logs_table_name, self._create_audit_logs_table)

def _load_or_create(self, table_name, create_fn):
    """Return the existing table, or build it when genuinely absent."""
    table = self.dynamodb.Table(table_name)
    try:
        table.load()
        return table
    except ClientError as e:
        if e.response.get('Error', {}).get('Code') != 'ResourceNotFoundException':
            raise  # Real failure (access denied, throttling) — surface it.
        return create_fn()
|
||||
|
||||
def _create_connections_table(self):
    """
    Create the SharePoint connections table.

    Partition key is the connection id; two GSIs support lookups by SaaS
    user and by organization.

    Fix: the original attached 'ProvisionedThroughput' to each GSI while
    also requesting BillingMode='PAY_PER_REQUEST'. DynamoDB rejects that
    combination with a ValidationException, so table creation could never
    succeed; the GSI throughput blocks are removed.
    """
    table = self.dynamodb.create_table(
        TableName=self.connections_table_name,
        KeySchema=[
            {'AttributeName': 'id', 'KeyType': 'HASH'}  # Partition key
        ],
        AttributeDefinitions=[
            {'AttributeName': 'id', 'AttributeType': 'S'},
            {'AttributeName': 'user_id', 'AttributeType': 'S'},
            {'AttributeName': 'organization_id', 'AttributeType': 'S'}
        ],
        GlobalSecondaryIndexes=[
            {
                'IndexName': 'user_id-index',
                'KeySchema': [
                    {'AttributeName': 'user_id', 'KeyType': 'HASH'}
                ],
                'Projection': {'ProjectionType': 'ALL'}
            },
            {
                'IndexName': 'organization_id-index',
                'KeySchema': [
                    {'AttributeName': 'organization_id', 'KeyType': 'HASH'}
                ],
                'Projection': {'ProjectionType': 'ALL'}
            }
        ],
        BillingMode='PAY_PER_REQUEST'  # On-demand pricing
    )
    # create_table is asynchronous on the service side; block until ACTIVE.
    table.wait_until_exists()
    return table
|
||||
|
||||
def _create_oauth_states_table(self):
    """Create the OAuth-state table; TTL auto-purges expired states."""
    states_table = self.dynamodb.create_table(
        TableName=self.oauth_states_table_name,
        KeySchema=[{'AttributeName': 'state', 'KeyType': 'HASH'}],
        AttributeDefinitions=[{'AttributeName': 'state', 'AttributeType': 'S'}],
        BillingMode='PAY_PER_REQUEST'
    )
    # create_table is asynchronous; wait until the table is usable.
    states_table.wait_until_exists()

    # Turn on TTL so DynamoDB removes stale state records automatically.
    try:
        self.dynamodb.meta.client.update_time_to_live(
            TableName=self.oauth_states_table_name,
            TimeToLiveSpecification={
                'Enabled': True,
                'AttributeName': 'ttl'
            }
        )
    except ClientError:
        pass  # TTL already enabled or not supported

    return states_table
|
||||
|
||||
def _create_audit_logs_table(self):
    """
    Create the audit logs table.

    Keyed by (connection_id, timestamp) so a connection's history reads
    back in time order; a GSI keyed by (user_id, timestamp) supports
    per-user queries.

    Fix: the original attached 'ProvisionedThroughput' to the GSI while
    also requesting BillingMode='PAY_PER_REQUEST'. DynamoDB rejects that
    combination with a ValidationException, so the throughput block is
    removed.
    """
    table = self.dynamodb.create_table(
        TableName=self.audit_logs_table_name,
        KeySchema=[
            {'AttributeName': 'connection_id', 'KeyType': 'HASH'},
            {'AttributeName': 'timestamp', 'KeyType': 'RANGE'}  # Sort key
        ],
        AttributeDefinitions=[
            {'AttributeName': 'connection_id', 'AttributeType': 'S'},
            {'AttributeName': 'timestamp', 'AttributeType': 'S'},
            {'AttributeName': 'user_id', 'AttributeType': 'S'}
        ],
        GlobalSecondaryIndexes=[
            {
                'IndexName': 'user_id-timestamp-index',
                'KeySchema': [
                    {'AttributeName': 'user_id', 'KeyType': 'HASH'},
                    {'AttributeName': 'timestamp', 'KeyType': 'RANGE'}
                ],
                'Projection': {'ProjectionType': 'ALL'}
            }
        ],
        BillingMode='PAY_PER_REQUEST'
    )
    # create_table is asynchronous on the service side; block until ACTIVE.
    table.wait_until_exists()
    return table
|
||||
|
||||
def initiate_connection(
    self,
    user_id: str,
    organization_id: Optional[str] = None,
    return_url: Optional[str] = None
) -> str:
    """
    Initiate OAuth flow for a user to connect their SharePoint.

    Args:
        user_id: Your SaaS user ID
        organization_id: Your SaaS organization ID (if applicable)
        return_url: URL to redirect to after successful connection

    Returns:
        Authorization URL to redirect user to
    """
    # Generate secure state token (32 random bytes, URL-safe)
    state = secrets.token_urlsafe(32)

    # Store state in DynamoDB for CSRF protection.
    # States are single-use and valid for 10 minutes; complete_connection()
    # enforces both, since DynamoDB TTL deletion is not immediate.
    expires_at = datetime.utcnow() + timedelta(minutes=10)
    ttl = int(expires_at.timestamp())  # Unix timestamp for DynamoDB TTL

    self.oauth_states_table.put_item(
        Item={
            'state': state,
            'user_id': user_id,
            'organization_id': organization_id or '',
            'return_url': return_url or '',
            'expires_at': datetime_to_iso(expires_at),
            'ttl': ttl,  # DynamoDB will auto-delete after this time
            'used': False
        }
    )

    # Generate the Microsoft login URL carrying our state token
    auth_url = self.msal_app.get_authorization_request_url(
        scopes=self.scopes,
        state=state,
        redirect_uri=self.redirect_uri
    )

    return auth_url
|
||||
|
||||
def complete_connection(
    self,
    auth_code: str,
    state: str,
    ip_address: Optional[str] = None,
    user_agent: Optional[str] = None
) -> SharePointConnectionInfo:
    """
    Complete OAuth flow and store the connection.

    Args:
        auth_code: Authorization code from OAuth callback
        state: State parameter from OAuth callback
        ip_address: User's IP address (for audit log)
        user_agent: User's user agent (for audit log)

    Returns:
        Connection information

    Raises:
        ValueError: If state is invalid, expired, or already used
        Exception: If token acquisition fails
    """
    # Validate state (CSRF protection)
    try:
        response = self.oauth_states_table.get_item(Key={'state': state})
        oauth_state = response.get('Item')
    except ClientError:
        oauth_state = None

    if not oauth_state:
        raise ValueError("Invalid OAuth state")

    # Reject expired or already-consumed states
    expires_at = iso_to_datetime(oauth_state.get('expires_at'))
    if oauth_state.get('used') or datetime.utcnow() > expires_at:
        raise ValueError("OAuth state expired or already used")

    # Mark state as used with a CONDITIONAL write. Fix: the original
    # unconditional update left a race window in which two concurrent
    # callbacks could both pass the 'used' check above and replay the
    # same state; the condition makes consumption atomic.
    try:
        self.oauth_states_table.update_item(
            Key={'state': state},
            UpdateExpression='SET used = :val',
            ConditionExpression=Attr('used').eq(False),
            ExpressionAttributeValues={':val': True}
        )
    except ClientError as e:
        if e.response.get('Error', {}).get('Code') == 'ConditionalCheckFailedException':
            raise ValueError("OAuth state expired or already used")
        raise

    # Exchange code for tokens
    token_response = self.msal_app.acquire_token_by_authorization_code(
        code=auth_code,
        scopes=self.scopes,
        redirect_uri=self.redirect_uri
    )

    if "error" in token_response:
        raise Exception(f"Token acquisition failed: {token_response.get('error_description', token_response['error'])}")

    # Get user info from Microsoft
    user_info = self._get_user_info(token_response["access_token"])

    # Calculate token expiry (defaults to 1 hour when not reported)
    expires_in = token_response.get("expires_in", 3600)
    expires_at = datetime.utcnow() + timedelta(seconds=expires_in)

    # Encrypt tokens before they ever touch storage
    encrypted_access_token = self.encryption.encrypt(token_response["access_token"])
    encrypted_refresh_token = self.encryption.encrypt(token_response.get("refresh_token", ""))

    # Create connection record
    user_id = oauth_state['user_id']
    organization_id = oauth_state.get('organization_id') or None
    connection_id = self._generate_connection_id(user_id, user_info["id"])

    now = datetime.utcnow()
    connection_item = {
        'id': connection_id,
        'user_id': user_id,
        'organization_id': organization_id or '',
        'tenant_id': user_info.get('tenantId', self.tenant_id),
        'microsoft_user_id': user_info["id"],
        'connection_name': f"{user_info.get('displayName', 'SharePoint')} - {user_info.get('userPrincipalName', '')}",
        'encrypted_access_token': encrypted_access_token,
        'encrypted_refresh_token': encrypted_refresh_token,
        'token_expires_at': datetime_to_iso(expires_at),
        'scopes': json.dumps(self.scopes),
        'is_active': True,
        'created_at': datetime_to_iso(now),
        'updated_at': datetime_to_iso(now),
        'last_used_at': datetime_to_iso(now)
    }

    # Store connection (put_item overwrites any prior record with this id)
    self.connections_table.put_item(Item=connection_item)

    # Audit log
    self._log_activity(
        connection_id=connection_id,
        user_id=user_id,
        action="connection_created",
        status="success",
        ip_address=ip_address,
        user_agent=user_agent
    )

    return SharePointConnectionInfo(
        id=connection_id,
        user_id=user_id,
        organization_id=organization_id,
        connection_name=connection_item['connection_name'],
        tenant_id=connection_item['tenant_id'],
        microsoft_user_id=user_info["id"],
        is_active=True,
        created_at=now,
        last_used_at=now
    )
|
||||
|
||||
def get_valid_token(
    self,
    connection_id: str,
    user_id: str
) -> str:
    """
    Get a valid access token, refreshing if necessary.

    Args:
        connection_id: SharePoint connection ID
        user_id: Your SaaS user ID (for authorization check)

    Returns:
        Valid access token

    Raises:
        ValueError: If connection not found or doesn't belong to user
        Exception: If token refresh fails
    """
    # Get connection (with authorization check!)
    try:
        response = self.connections_table.get_item(Key={'id': connection_id})
        connection = response.get('Item')
    except ClientError:
        # A DynamoDB failure is treated the same as "not found" below.
        connection = None

    # Ownership + active checks collapse into a single opaque error so a
    # caller cannot probe which connection IDs exist.
    if not connection or connection.get('user_id') != user_id or not connection.get('is_active'):
        raise ValueError("Connection not found or access denied")

    # Check if token needs refresh.
    # Refresh proactively 5 minutes before expiry so the token cannot
    # expire mid-request between this check and its actual use.
    token_expires_at = iso_to_datetime(connection['token_expires_at'])
    if datetime.utcnow() >= token_expires_at - timedelta(minutes=5):
        # Refresh token
        refresh_token = self.encryption.decrypt(connection['encrypted_refresh_token'])

        token_response = self.msal_app.acquire_token_by_refresh_token(
            refresh_token=refresh_token,
            scopes=self.scopes
        )

        if "error" in token_response:
            # Record the failed refresh in the audit trail before raising.
            self._log_activity(
                connection_id=connection_id,
                user_id=user_id,
                action="token_refresh",
                status="failure",
                details=json.dumps({"error": token_response.get("error")})
            )
            raise Exception(f"Token refresh failed: {token_response.get('error_description', token_response['error'])}")

        # Update stored tokens
        new_access_token = self.encryption.encrypt(token_response["access_token"])
        new_refresh_token = token_response.get("refresh_token")
        expires_in = token_response.get("expires_in", 3600)
        new_expires_at = datetime.utcnow() + timedelta(seconds=expires_in)

        update_expression = "SET encrypted_access_token = :access, token_expires_at = :expires, updated_at = :updated, last_used_at = :used"
        expression_values = {
            ':access': new_access_token,
            ':expires': datetime_to_iso(new_expires_at),
            ':updated': datetime_to_iso(datetime.utcnow()),
            ':used': datetime_to_iso(datetime.utcnow())
        }

        # Microsoft may rotate the refresh token; only persist a new one
        # when it is actually present in the response.
        if new_refresh_token:
            update_expression += ", encrypted_refresh_token = :refresh"
            expression_values[':refresh'] = self.encryption.encrypt(new_refresh_token)

        self.connections_table.update_item(
            Key={'id': connection_id},
            UpdateExpression=update_expression,
            ExpressionAttributeValues=expression_values
        )

        self._log_activity(
            connection_id=connection_id,
            user_id=user_id,
            action="token_refresh",
            status="success"
        )

        # Keep the in-memory record consistent so the decrypt below
        # returns the freshly issued token, not the stale one.
        connection['encrypted_access_token'] = new_access_token

    else:
        # Update last used timestamp
        self.connections_table.update_item(
            Key={'id': connection_id},
            UpdateExpression='SET last_used_at = :val',
            ExpressionAttributeValues={':val': datetime_to_iso(datetime.utcnow())}
        )

    # Decrypt and return access token
    return self.encryption.decrypt(connection['encrypted_access_token'])
|
||||
|
||||
def list_connections(
    self,
    user_id: str,
    organization_id: Optional[str] = None
) -> List[SharePointConnectionInfo]:
    """
    List all active SharePoint connections for a user or organization.

    Args:
        user_id: Your SaaS user ID
        organization_id: Your SaaS organization ID (optional); when given,
            only connections belonging to that organization are returned.

    Returns:
        List of connections
    """
    # Query by user_id using the GSI. DynamoDB returns results in pages
    # (up to 1 MB each); the previous implementation read only the first
    # page, silently dropping connections for users with many records —
    # follow LastEvaluatedKey until exhausted.
    query_kwargs = {
        'IndexName': 'user_id-index',
        'KeyConditionExpression': Key('user_id').eq(user_id),
        'FilterExpression': Attr('is_active').eq(True),
    }
    connections = []
    while True:
        response = self.connections_table.query(**query_kwargs)
        connections.extend(response.get('Items', []))
        last_key = response.get('LastEvaluatedKey')
        if not last_key:
            break
        query_kwargs['ExclusiveStartKey'] = last_key

    # Filter by organization if specified
    if organization_id:
        connections = [c for c in connections if c.get('organization_id') == organization_id]

    return [
        SharePointConnectionInfo(
            id=conn['id'],
            user_id=conn['user_id'],
            # Empty string is the stored placeholder for "no organization".
            organization_id=conn.get('organization_id') or None,
            connection_name=conn.get('connection_name'),
            tenant_id=conn.get('tenant_id'),
            microsoft_user_id=conn.get('microsoft_user_id'),
            is_active=conn.get('is_active', False),
            created_at=iso_to_datetime(conn.get('created_at')),
            last_used_at=iso_to_datetime(conn.get('last_used_at'))
        )
        for conn in connections
    ]
|
||||
|
||||
def disconnect(
    self,
    connection_id: str,
    user_id: str,
    ip_address: Optional[str] = None
):
    """
    Disconnect (deactivate) a SharePoint connection.

    Args:
        connection_id: SharePoint connection ID
        user_id: Your SaaS user ID (for authorization)
        ip_address: User's IP for audit log
    """
    # Fetch the record first so we can verify the caller actually owns it.
    try:
        item = self.connections_table.get_item(Key={'id': connection_id}).get('Item')
    except ClientError:
        item = None

    if not item or item.get('user_id') != user_id:
        raise ValueError("Connection not found or access denied")

    # Soft-delete: flip the active flag rather than removing the record,
    # preserving history and the audit trail.
    self.connections_table.update_item(
        Key={'id': connection_id},
        UpdateExpression='SET is_active = :val, updated_at = :updated',
        ExpressionAttributeValues={
            ':val': False,
            ':updated': datetime_to_iso(datetime.utcnow()),
        },
    )

    self._log_activity(
        connection_id=connection_id,
        user_id=user_id,
        action="connection_disconnected",
        status="success",
        ip_address=ip_address,
    )
|
||||
|
||||
def _get_user_info(self, access_token: str) -> Dict[str, Any]:
    """
    Fetch the signed-in Microsoft user's profile from Graph.

    Args:
        access_token: OAuth bearer token for Microsoft Graph.

    Returns:
        Dict with id, displayName, userPrincipalName and mail fields.

    Raises:
        requests.HTTPError: If Graph returns a non-2xx status.
    """
    headers = {"Authorization": f"Bearer {access_token}"}
    # Explicit timeout: without one, requests can block indefinitely if
    # Graph hangs, wedging the OAuth callback handler.
    response = requests.get(
        "https://graph.microsoft.com/v1.0/me?$select=id,displayName,userPrincipalName,mail",
        headers=headers,
        timeout=30
    )
    response.raise_for_status()
    return response.json()
|
||||
|
||||
def _generate_connection_id(self, user_id: str, microsoft_user_id: str) -> str:
|
||||
"""Generate deterministic connection ID."""
|
||||
combined = f"{user_id}:{microsoft_user_id}"
|
||||
return hashlib.sha256(combined.encode()).hexdigest()[:32]
|
||||
|
||||
def _log_activity(
    self,
    connection_id: str,
    user_id: str,
    action: str,
    status: str,
    ip_address: Optional[str] = None,
    user_agent: Optional[str] = None,
    details: Optional[str] = None
):
    """Write a single audit-trail entry for a connection-related event."""
    # Normalize optional fields to empty strings before persisting.
    entry = {
        'connection_id': connection_id,
        'timestamp': datetime_to_iso(datetime.utcnow()),
        'user_id': user_id,
        'action': action,
        'status': status,
        'ip_address': ip_address or '',
        'user_agent': user_agent or '',
        'details': details or '',
    }
    self.audit_logs_table.put_item(Item=entry)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SHAREPOINT API CLIENT
|
||||
# ============================================================================
|
||||
|
||||
class SecureSharePointClient:
    """
    SharePoint API client with automatic token management.

    Use this to make SharePoint API calls on behalf of connected users.
    Every Graph request carries an explicit timeout so a hung upstream
    cannot block the caller indefinitely.
    """

    # Timeout (seconds) applied to every Microsoft Graph request.
    REQUEST_TIMEOUT = 30

    def __init__(self, connector: DynamoDBSharePointConnector, connection_id: str, user_id: str):
        """
        Initialize client.

        Args:
            connector: DynamoDBSharePointConnector instance
            connection_id: SharePoint connection ID
            user_id: Your SaaS user ID
        """
        self.connector = connector
        self.connection_id = connection_id
        self.user_id = user_id

    def _get_headers(self) -> Dict[str, str]:
        """Get headers with a valid (refreshed if necessary) access token."""
        access_token = self.connector.get_valid_token(self.connection_id, self.user_id)
        return {
            "Authorization": f"Bearer {access_token}",
            "Accept": "application/json"
        }

    def list_sites(self) -> List[Dict[str, Any]]:
        """List SharePoint sites the user has access to."""
        response = requests.get(
            "https://graph.microsoft.com/v1.0/sites?search=*",
            headers=self._get_headers(),
            timeout=self.REQUEST_TIMEOUT
        )
        response.raise_for_status()
        return response.json().get("value", [])

    def read_file(self, site_id: str, file_path: str, as_text: bool = True) -> Any:
        """Read file content.

        Args:
            site_id: Graph site ID.
            file_path: Path of the file relative to the drive root.
            as_text: Return decoded text when True, raw bytes otherwise.
        """
        encoded_path = requests.utils.quote(file_path)
        url = f"https://graph.microsoft.com/v1.0/sites/{site_id}/drive/root:/{encoded_path}:/content"

        response = requests.get(url, headers=self._get_headers(), timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()

        return response.text if as_text else response.content

    def list_files(self, site_id: str, folder_path: str = "") -> List[Dict[str, Any]]:
        """List files in a folder (the drive root when folder_path is empty)."""
        if folder_path:
            encoded_path = requests.utils.quote(folder_path)
            url = f"https://graph.microsoft.com/v1.0/sites/{site_id}/drive/root:/{encoded_path}:/children"
        else:
            url = f"https://graph.microsoft.com/v1.0/sites/{site_id}/drive/root/children"

        response = requests.get(url, headers=self._get_headers(), timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        return response.json().get("value", [])

    def search_files(self, site_id: str, query: str) -> List[Dict[str, Any]]:
        """Search for files by free-text query."""
        # URL-encode the query: the previous raw interpolation produced
        # malformed URLs for queries containing spaces or quotes, and let
        # a quote character break out of the q='...' literal.
        encoded_query = requests.utils.quote(query)
        url = f"https://graph.microsoft.com/v1.0/sites/{site_id}/drive/root/search(q='{encoded_query}')"
        response = requests.get(url, headers=self._get_headers(), timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        return response.json().get("value", [])
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# FLASK INTEGRATION EXAMPLE
|
||||
# ============================================================================
|
||||
|
||||
def create_app(connector: DynamoDBSharePointConnector) -> Flask:
    """Create Flask app with DynamoDB SharePoint connector integration.

    Args:
        connector: Configured DynamoDBSharePointConnector used by all routes.

    Returns:
        A Flask application exposing the OAuth connect/callback flow and a
        small JSON API for listing, using and disconnecting connections.
    """

    app = Flask(__name__)
    # Fall back to a random key so sessions work out of the box; set
    # FLASK_SECRET_KEY in production, or sessions reset on every restart.
    app.secret_key = os.getenv("FLASK_SECRET_KEY", secrets.token_urlsafe(32))

    def require_auth(f):
        """Decorator to require authentication (user_id present in session)."""
        @wraps(f)
        def decorated_function(*args, **kwargs):
            if "user_id" not in session:
                return jsonify({"error": "Authentication required"}), 401
            return f(*args, **kwargs)
        return decorated_function

    @app.route("/sharepoint/connect")
    @require_auth
    def connect_sharepoint():
        """Initiate SharePoint connection (redirects to Microsoft login)."""
        user_id = session["user_id"]
        organization_id = session.get("organization_id")

        auth_url = connector.initiate_connection(
            user_id=user_id,
            organization_id=organization_id,
            return_url=request.args.get("return_url", "/dashboard")
        )

        return redirect(auth_url)

    @app.route("/sharepoint/callback")
    def sharepoint_callback():
        """OAuth callback endpoint.

        Intentionally unauthenticated: Microsoft redirects here, and the
        state parameter ties the request back to the initiating user.
        """
        if "error" in request.args:
            return jsonify({
                "error": request.args.get("error_description", request.args["error"])
            }), 400

        auth_code = request.args.get("code")
        state = request.args.get("state")

        if not auth_code or not state:
            return jsonify({"error": "Invalid callback"}), 400

        try:
            connection_info = connector.complete_connection(
                auth_code=auth_code,
                state=state,
                ip_address=request.remote_addr,
                user_agent=request.headers.get("User-Agent")
            )

            session["sharepoint_connection_id"] = connection_info.id

            return jsonify({
                "success": True,
                "connection": {
                    "id": connection_info.id,
                    "name": connection_info.connection_name
                }
            })

        except Exception as e:
            return jsonify({"error": str(e)}), 500

    @app.route("/api/sharepoint/connections")
    @require_auth
    def list_connections():
        """List user's SharePoint connections."""
        user_id = session["user_id"]
        connections = connector.list_connections(user_id)

        return jsonify({
            "connections": [
                {
                    "id": conn.id,
                    "name": conn.connection_name,
                    "created_at": conn.created_at.isoformat(),
                    "last_used_at": conn.last_used_at.isoformat() if conn.last_used_at else None
                }
                for conn in connections
            ]
        })

    @app.route("/api/sharepoint/connections/<connection_id>/disconnect", methods=["POST"])
    @require_auth
    def disconnect_sharepoint(connection_id):
        """Disconnect SharePoint."""
        user_id = session["user_id"]

        try:
            connector.disconnect(
                connection_id=connection_id,
                user_id=user_id,
                ip_address=request.remote_addr
            )
            return jsonify({"success": True})
        except Exception as e:
            # Ownership failures surface as ValueError from the connector.
            return jsonify({"error": str(e)}), 400

    @app.route("/api/sharepoint/<connection_id>/sites")
    @require_auth
    def get_sites(connection_id):
        """Get SharePoint sites."""
        user_id = session["user_id"]

        try:
            client = SecureSharePointClient(connector, connection_id, user_id)
            sites = client.list_sites()
            # decimal_to_float guards against Decimal values json can't serialize.
            return jsonify({"sites": decimal_to_float(sites)})
        except Exception as e:
            return jsonify({"error": str(e)}), 500

    return app
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Initialize connector.
    # All secrets come from the environment (see .env.example). The
    # required values (client id/secret, redirect URI, encryption key)
    # deliberately have no defaults.
    connector = DynamoDBSharePointConnector(
        client_id=os.getenv("SHAREPOINT_CLIENT_ID"),
        client_secret=os.getenv("SHAREPOINT_CLIENT_SECRET"),
        redirect_uri=os.getenv("REDIRECT_URI"),
        encryption_key=os.getenv("ENCRYPTION_KEY"),
        aws_region=os.getenv("AWS_REGION", "us-east-1"),
        table_prefix=os.getenv("TABLE_PREFIX", ""),
        tenant_id=os.getenv("SHAREPOINT_TENANT_ID", "common")
    )

    # Create and run Flask app (debug off; use FLASK_DEBUG via a WSGI
    # runner for development).
    app = create_app(connector)
    app.run(debug=False, port=5000)
|
||||
23
setup_agent.py
Normal file
23
setup_agent.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Bootstrap script: registers the SharePoint search tool and a chat agent in ToothFairyAI."""
import os
from dotenv import load_dotenv
from toothfairy_client import ToothFairyClient

# 1. Manually load the .env file so the script can see your keys
load_dotenv()

# 2. Initialize the client (reads its own credentials from the environment)
client = ToothFairyClient()

# 3. Create the Tool pulling the URL from .env
ngrok_url = os.getenv("NGROK_URL")

if not ngrok_url:
    # Fail fast with a non-zero exit code so CI/automation notices.
    print("❌ Error: NGROK_URL not found in .env file. Please add it and try again.")
    exit(1)

tool_id = client.create_search_tool(ngrok_url)
print(f"✅ Tool created with ID: {tool_id}")

# 4. Create the Agent and attach the Tool
agent_id = client.create_agent(tool_id)
print(f"✅ Agent created with ID: {agent_id}")
|
||||
30
sonar-project.properties
Normal file
30
sonar-project.properties
Normal file
@@ -0,0 +1,30 @@
|
||||
# SonarQube Project Configuration for SharePoint Connector
|
||||
|
||||
# Project identification
|
||||
sonar.projectKey=sharepoint-connector
|
||||
sonar.projectName=SharePoint Connector Plugin
|
||||
sonar.projectVersion=1.0
|
||||
|
||||
# Source code
|
||||
sonar.sources=.
|
||||
sonar.sourceEncoding=UTF-8
|
||||
|
||||
# Tests
|
||||
sonar.tests=.
|
||||
sonar.test.inclusions=test_*.py
|
||||
|
||||
# Exclusions
|
||||
sonar.exclusions=venv/**,**/__pycache__/**,*.pyc,.venv/**,htmlcov/**,templates/**,static/**,**/migrations/**
|
||||
|
||||
# Python specific
|
||||
sonar.python.version=3.11
|
||||
sonar.python.coverage.reportPaths=coverage.xml
|
||||
|
||||
# Coverage exclusions (files that don't need coverage)
|
||||
sonar.coverage.exclusions=test_*.py,**/__init__.py,**/migrations/**,clear_data.py
|
||||
|
||||
# Duplications
|
||||
sonar.cpd.exclusions=test_*.py
|
||||
|
||||
# Analysis parameters
|
||||
sonar.scm.provider=git
|
||||
1310
static/style.css
Normal file
1310
static/style.css
Normal file
File diff suppressed because it is too large
Load Diff
155
storage/credentials_storage.py
Normal file
155
storage/credentials_storage.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""
|
||||
Secure Credentials Storage
|
||||
|
||||
Encrypts and stores SharePoint OAuth credentials to disk for persistence across restarts.
|
||||
Uses Fernet symmetric encryption to protect sensitive data.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, Optional
|
||||
from cryptography.fernet import Fernet
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CredentialsStorage:
    """
    Secure storage for SharePoint OAuth credentials.

    Stores credentials encrypted on disk at ~/.sharepoint_credentials/credentials.enc
    Each user's credentials are stored separately and encrypted with a
    single Fernet key kept alongside the data (key.key, mode 0600).
    """

    def __init__(self, storage_dir: Optional[str] = None):
        """
        Initialize credentials storage.

        Args:
            storage_dir: Directory to store credentials (default: ~/.sharepoint_credentials)
        """
        if storage_dir is None:
            storage_dir = os.path.expanduser("~/.sharepoint_credentials")

        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)
        # Lock down the directory itself (owner-only). mkdir honours the
        # process umask, which previously could leave the key file's
        # parent directory readable by other local users.
        os.chmod(self.storage_dir, 0o700)

        self.credentials_file = self.storage_dir / "credentials.enc"
        self.key_file = self.storage_dir / "key.key"

        # Initialize encryption key
        self.cipher = self._get_or_create_cipher()

        # Load existing credentials
        self.credentials: Dict[str, Dict] = self._load_credentials()

    def _get_or_create_cipher(self) -> Fernet:
        """Get or create the Fernet encryption key (persisted in key.key)."""
        if self.key_file.exists():
            with open(self.key_file, 'rb') as f:
                key = f.read()
        else:
            key = Fernet.generate_key()
            with open(self.key_file, 'wb') as f:
                f.write(key)
            # Set restrictive permissions on key file
            os.chmod(self.key_file, 0o600)

        return Fernet(key)

    def _load_credentials(self) -> Dict[str, Dict]:
        """Load and decrypt credentials from disk; return {} when unreadable."""
        if not self.credentials_file.exists():
            return {}

        try:
            with open(self.credentials_file, 'rb') as f:
                encrypted_data = f.read()

            if not encrypted_data:
                return {}

            decrypted_data = self.cipher.decrypt(encrypted_data)
            return json.loads(decrypted_data.decode('utf-8'))
        except Exception as e:
            # Best-effort: a corrupt or undecryptable store degrades to
            # "no saved credentials" instead of crashing startup.
            logger.error("Failed to load credentials: %s", e)
            return {}

    def _save_credentials(self):
        """Encrypt and write all credentials to disk.

        Raises:
            Exception: Re-raises any write/encryption failure after logging.
        """
        try:
            data = json.dumps(self.credentials).encode('utf-8')
            encrypted_data = self.cipher.encrypt(data)

            with open(self.credentials_file, 'wb') as f:
                f.write(encrypted_data)

            # Set restrictive permissions on credentials file
            os.chmod(self.credentials_file, 0o600)

            logger.info("Credentials saved to disk")
        except Exception as e:
            logger.error("Failed to save credentials: %s", e)
            raise

    def save_config(self, user_id: str, config: Dict) -> None:
        """
        Save user configuration.

        Args:
            user_id: User identifier
            config: Configuration dictionary containing:
                - tenant_id: Azure tenant ID
                - client_id: Azure client ID
                - client_secret: Azure client secret
        """
        self.credentials[user_id] = config
        self._save_credentials()
        logger.info("Saved credentials for user %s", user_id)

    def get_config(self, user_id: str) -> Optional[Dict]:
        """
        Get user configuration.

        Args:
            user_id: User identifier

        Returns:
            Configuration dictionary or None if not found
        """
        return self.credentials.get(user_id)

    def delete_config(self, user_id: str) -> bool:
        """
        Delete user configuration.

        Args:
            user_id: User identifier

        Returns:
            True if deleted, False if not found
        """
        if user_id in self.credentials:
            del self.credentials[user_id]
            self._save_credentials()
            logger.info("Deleted credentials for user %s", user_id)
            return True
        return False

    def list_users(self) -> list:
        """Get list of user IDs with stored credentials."""
        return list(self.credentials.keys())
|
||||
|
||||
|
||||
# Global credentials storage instance (lazily created singleton)
_credentials_storage = None


def get_credentials_storage() -> CredentialsStorage:
    """Get the global credentials storage instance, creating it on first use."""
    global _credentials_storage
    if _credentials_storage is None:
        _credentials_storage = CredentialsStorage()
    return _credentials_storage
|
||||
57
task-definition.json
Normal file
57
task-definition.json
Normal file
@@ -0,0 +1,57 @@
|
||||
{
|
||||
"family": "sharepoint-connector",
|
||||
"networkMode": "awsvpc",
|
||||
"requiresCompatibilities": ["FARGATE"],
|
||||
"cpu": "512",
|
||||
"memory": "1024",
|
||||
"executionRoleArn": "arn:aws:iam::YOUR_ACCOUNT_ID:role/ecsTaskExecutionRole",
|
||||
"taskRoleArn": "arn:aws:iam::YOUR_ACCOUNT_ID:role/sharepoint-connector-task-role",
|
||||
"containerDefinitions": [
|
||||
{
|
||||
"name": "sharepoint-connector",
|
||||
"image": "YOUR_ACCOUNT_ID.dkr.ecr.ap-southeast-2.amazonaws.com/sharepoint-connector:latest",
|
||||
"essential": true,
|
||||
"portMappings": [
|
||||
{
|
||||
"containerPort": 8000,
|
||||
"protocol": "tcp"
|
||||
}
|
||||
],
|
||||
"environment": [
|
||||
{
|
||||
"name": "AWS_REGION",
|
||||
"value": "ap-southeast-2"
|
||||
},
|
||||
{
|
||||
"name": "TABLE_PREFIX",
|
||||
"value": "prod_"
|
||||
},
|
||||
{
|
||||
"name": "PORT",
|
||||
"value": "8000"
|
||||
}
|
||||
],
|
||||
"secrets": [
|
||||
{
|
||||
"name": "FLASK_SECRET_KEY",
|
||||
"valueFrom": "arn:aws:secretsmanager:ap-southeast-2:YOUR_ACCOUNT_ID:secret:sharepoint-connector/flask-secret"
|
||||
}
|
||||
],
|
||||
"healthCheck": {
|
||||
"command": ["CMD-SHELL", "python -c \"import requests; requests.get('http://localhost:8000/health')\""],
|
||||
"interval": 30,
|
||||
"timeout": 5,
|
||||
"retries": 3,
|
||||
"startPeriod": 60
|
||||
},
|
||||
"logConfiguration": {
|
||||
"logDriver": "awslogs",
|
||||
"options": {
|
||||
"awslogs-group": "/ecs/sharepoint-connector",
|
||||
"awslogs-region": "ap-southeast-2",
|
||||
"awslogs-stream-prefix": "ecs"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
25
templates/error.html
Normal file
25
templates/error.html
Normal file
@@ -0,0 +1,25 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Error - SharePoint Connector</title>
|
||||
<link rel="stylesheet" href="/static/style.css">
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<header>
|
||||
<h1>❌ Error</h1>
|
||||
</header>
|
||||
|
||||
<div class="status-card">
|
||||
<div class="error">
|
||||
<h3>Something went wrong</h3>
|
||||
<p>{{ error }}</p>
|
||||
<br>
|
||||
<a href="/" class="btn btn-primary">Go Home</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
1417
templates/index.html
Normal file
1417
templates/index.html
Normal file
File diff suppressed because it is too large
Load Diff
286
test_app.py
Normal file
286
test_app.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
Unit tests for app.py Flask routes
|
||||
"""
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from app import app
|
||||
|
||||
|
||||
@pytest.fixture
def client():
    """Create a test client for the Flask app."""
    # TESTING disables error catching so route exceptions surface in tests;
    # a fixed SECRET_KEY keeps session cookies decodable across requests.
    app.config['TESTING'] = True
    app.config['SECRET_KEY'] = 'test-secret-key'
    with app.test_client() as client:
        yield client
|
||||
|
||||
|
||||
@pytest.fixture
def mock_session():
    """Create a mock session."""
    # Minimal session payload: authenticated routes only require a user_id.
    return {"user_id": "test_user_123"}
|
||||
|
||||
|
||||
class TestHealthCheck:
    """Tests for the /health liveness endpoint."""

    def test_health_endpoint(self, client):
        """Test /health endpoint returns healthy status."""
        resp = client.get('/health')
        assert resp.status_code == 200
        payload = json.loads(resp.data)
        assert payload['status'] == 'healthy'
|
||||
|
||||
|
||||
class TestIndexRoute:
    """Test index route."""

    def test_index_returns_template(self, client):
        """Test index route returns 200."""
        # Patch render_template so the test does not depend on the real
        # templates/ directory being present.
        with patch('app.render_template') as mock_render:
            mock_render.return_value = '<html>Test</html>'
            response = client.get('/')
            assert response.status_code == 200
|
||||
|
||||
|
||||
class TestConfigurationAPI:
    """Test configuration API endpoints."""

    def test_check_config_not_configured(self, client):
        """Test check_config returns false when not configured."""
        response = client.get('/api/config/check')
        assert response.status_code == 200
        data = json.loads(response.data)
        # Only assert the key is present: the actual True/False value can
        # depend on credentials persisted on the machine running the tests.
        assert 'configured' in data

    def test_save_config_missing_credentials(self, client):
        """Test save_config fails without credentials."""
        response = client.post(
            '/api/config/save',
            data=json.dumps({}),
            content_type='application/json'
        )
        assert response.status_code == 400
        data = json.loads(response.data)
        assert 'error' in data

    @patch('app.credentials_storage')
    def test_save_config_success(self, mock_storage, client):
        """Test save_config succeeds with valid credentials."""
        # Stub out disk-backed storage so no real credentials are written.
        mock_storage.get_config.return_value = None
        mock_storage.save_config.return_value = None

        response = client.post(
            '/api/config/save',
            data=json.dumps({
                'client_id': 'test-client-id',
                'client_secret': 'test-client-secret',
                'tenant_id': 'common'
            }),
            content_type='application/json'
        )
        assert response.status_code == 200
        data = json.loads(response.data)
        assert data['success'] is True

    @patch('app.credentials_storage')
    def test_reset_config(self, mock_storage, client):
        """Test reset_config clears configuration."""
        mock_storage.delete_config.return_value = None

        response = client.post('/api/config/reset')
        assert response.status_code == 200
        data = json.loads(response.data)
        assert data['success'] is True
|
||||
|
||||
|
||||
class TestSharePointAPI:
    """Test SharePoint API endpoints."""

    def test_list_connections_not_configured(self, client):
        """Test list_connections fails without configuration."""
        response = client.get('/api/sharepoint/connections')
        assert response.status_code == 400
        data = json.loads(response.data)
        assert 'error' in data

    def test_get_sites_missing_connection_id(self, client):
        """Test get_sites requires valid connection_id."""
        response = client.get('/api/sharepoint/invalid_conn/sites')
        assert response.status_code == 400

    @patch('app.get_or_create_connector')
    def test_get_files_missing_site_id(self, mock_connector, client):
        """Test get_files requires site_id parameter."""
        # Connector is stubbed so the route fails on parameter validation,
        # not on connector setup.
        mock_connector.return_value = Mock()

        response = client.get('/api/sharepoint/test_conn/files')
        assert response.status_code == 400
        data = json.loads(response.data)
        assert 'site_id is required' in data['error']

    @patch('app.get_or_create_connector')
    def test_read_file_missing_parameters(self, mock_connector, client):
        """Test read_file requires site_id and file_path."""
        mock_connector.return_value = Mock()

        response = client.get('/api/sharepoint/test_conn/read')
        assert response.status_code == 400
        data = json.loads(response.data)
        assert 'required' in data['error']
|
||||
|
||||
|
||||
class TestChatAPI:
    """Test chat API endpoints."""

    def test_chat_send_missing_parameters(self, client):
        """Test chat_send requires all parameters."""
        response = client.post(
            '/api/chat/send',
            data=json.dumps({}),
            content_type='application/json'
        )
        assert response.status_code == 400
        data = json.loads(response.data)
        assert 'required' in data['error']

    def test_chat_send_no_document_loaded(self, client):
        """Test chat_send fails without loaded document."""
        # All parameters supplied, but no document has been read into the
        # session yet, so the route should reject the request.
        response = client.post(
            '/api/chat/send',
            data=json.dumps({
                'site_id': 'test_site',
                'file_path': 'test.txt',
                'message': 'Hello'
            }),
            content_type='application/json'
        )
        assert response.status_code == 400
        data = json.loads(response.data)
        assert 'No document loaded' in data['error']

    def test_chat_history_missing_parameters(self, client):
        """Test chat_history requires parameters."""
        response = client.get('/api/chat/history')
        assert response.status_code == 400
        data = json.loads(response.data)
        assert 'required' in data['error']

    def test_chat_clear_missing_parameters(self, client):
        """Test chat_clear requires parameters."""
        response = client.post(
            '/api/chat/clear',
            data=json.dumps({}),
            content_type='application/json'
        )
        assert response.status_code == 400
|
||||
|
||||
|
||||
class TestLLMStatus:
    """Test LLM status endpoint."""

    @patch('app.llm_client')
    def test_llm_status_available(self, mock_llm, client):
        """Test LLM status returns available when service is up."""
        mock_llm.is_available.return_value = True

        response = client.get('/api/llm/status')
        assert response.status_code == 200
        data = json.loads(response.data)
        assert data['available'] is True
        assert 'provider' in data

    @patch('app.llm_client')
    def test_llm_status_unavailable(self, mock_llm, client):
        """Test LLM status returns unavailable when service is down."""
        mock_llm.is_available.side_effect = Exception("Connection failed")

        # The route is expected to swallow the failure and still answer
        # 200 with available=False rather than propagate a 5xx.
        response = client.get('/api/llm/status')
        assert response.status_code == 200
        data = json.loads(response.data)
        assert data['available'] is False
        assert 'error' in data
|
||||
|
||||
|
||||
class TestVectorStoreAPI:
    """Test vector store / multi-document chat API."""

    # Each test patches app.vector_store to None to simulate the optional
    # vector-store dependency being absent; routes should answer 503.

    @patch('app.vector_store', None)
    def test_add_document_vector_store_unavailable(self, client):
        """Test add_document fails when vector store is unavailable."""
        response = client.post(
            '/api/documents/add',
            data=json.dumps({}),
            content_type='application/json'
        )
        assert response.status_code == 503
        data = json.loads(response.data)
        assert 'Vector store not available' in data['error']

    @patch('app.vector_store', None)
    def test_list_tags_vector_store_unavailable(self, client):
        """Test list_tags fails when vector store is unavailable."""
        response = client.get('/api/documents/tags')
        assert response.status_code == 503

    @patch('app.vector_store', None)
    def test_chat_multi_vector_store_unavailable(self, client):
        """Test chat_multi fails when vector store is unavailable."""
        response = client.post(
            '/api/chat/multi',
            data=json.dumps({'message': 'test'}),
            content_type='application/json'
        )
        assert response.status_code == 503
|
||||
|
||||
|
||||
class TestIndexingAPI:
    """Tests for the background indexing API."""

    @patch('app.vector_store', None)
    @patch('app.get_or_create_connector')
    def test_start_indexing_vector_store_unavailable(self, mock_connector, client):
        """Starting an indexing job without a vector store yields 503."""
        mock_connector.return_value = Mock()

        resp = client.post(
            '/api/indexing/start',
            data=json.dumps({}),
            content_type='application/json',
        )
        assert resp.status_code == 503

    def test_get_indexing_status_not_found(self, client):
        """Querying the status of an unknown job yields 404."""
        with patch('app.get_indexer') as mock_indexer:
            mock_indexer.return_value.get_job_status.return_value = None

            resp = client.get('/api/indexing/status/unknown_job')
            assert resp.status_code == 404
|
||||
|
||||
|
||||
class TestErrorHandlers:
    """Tests for the application-level error handlers."""

    def test_404_handler(self, client):
        """Unknown routes return a JSON 404 payload with a 'Not found' error."""
        resp = client.get('/nonexistent-route')
        assert resp.status_code == 404

        payload = json.loads(resp.data)
        assert 'error' in payload
        assert payload['error'] == 'Not found'
|
||||
|
||||
|
||||
class TestDecorators:
    """Tests for the authentication decorators."""

    def test_require_auth_creates_user_id(self, client):
        """Hitting a protected route seeds the session with a user_id."""
        # Session starts empty.
        with client.session_transaction() as session:
            assert 'user_id' not in session

        client.get('/')

        # After one request, the decorator has populated user_id.
        with client.session_transaction() as session:
            assert 'user_id' in session
|
||||
145
test_document_parser.py
Normal file
145
test_document_parser.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
Unit tests for document_parser.py
|
||||
"""
|
||||
import pytest
|
||||
from document_parser import DocumentParser, get_file_info
|
||||
|
||||
|
||||
class TestDocumentParser:
    """Unit tests for the DocumentParser class."""

    def setup_method(self):
        """Create a fresh parser instance for every test."""
        self.parser = DocumentParser()

    def test_can_parse_supported_extensions(self):
        """can_parse returns True for every supported file type."""
        supported_files = [
            'document.txt', 'readme.md', 'data.csv', 'config.json',
            'report.pdf', 'document.docx', 'spreadsheet.xlsx', 'slides.pptx',
            'script.py', 'code.java', 'style.css', 'index.html'
        ]

        for filename in supported_files:
            # Fix: the failure message was an f-string with no placeholder,
            # so a failure never reported which filename was at fault.
            assert self.parser.can_parse(filename), f"Should parse {filename}"

    def test_can_parse_unsupported_extensions(self):
        """can_parse returns False for unsupported file types."""
        unsupported_files = [
            'image.png', 'video.mp4', 'audio.mp3', 'archive.zip',
            'binary.exe', 'document.doc'
        ]

        for filename in unsupported_files:
            # Fix: include the offending filename in the failure message.
            assert not self.parser.can_parse(filename), f"Should not parse {filename}"

    def test_get_extension(self):
        """_get_extension lower-cases and returns only the final suffix."""
        assert self.parser._get_extension('file.txt') == '.txt'
        assert self.parser._get_extension('FILE.TXT') == '.txt'
        assert self.parser._get_extension('archive.tar.gz') == '.gz'
        assert self.parser._get_extension('noextension') == ''

    def test_parse_text_utf8(self):
        """Parsing a UTF-8 text file returns its decoded content."""
        content = "Hello World\nThis is a test".encode('utf-8')
        result = self.parser.parse(content, 'test.txt')
        assert result == "Hello World\nThis is a test"

    def test_parse_text_multiple_encodings(self):
        """_parse_text decodes both UTF-8 and Latin-1 input."""
        content = "Test content"

        # UTF-8
        result = self.parser._parse_text(content.encode('utf-8'))
        assert result == "Test content"

        # Latin-1
        result = self.parser._parse_text(content.encode('latin-1'))
        assert result == "Test content"

    def test_parse_unsupported_file_raises_error(self):
        """Parsing an unsupported file type raises ValueError."""
        with pytest.raises(ValueError, match="Unsupported file type"):
            self.parser.parse(b"content", "file.exe")

    def test_parse_json(self):
        """Parsing a JSON file preserves its keys and values."""
        content = '{"key": "value", "number": 123}'.encode('utf-8')
        result = self.parser.parse(content, 'data.json')
        assert '"key": "value"' in result
        assert '"number": 123' in result

    def test_parse_csv(self):
        """Parsing a CSV file preserves its cell values."""
        content = "name,age,city\nAlice,30,NYC\nBob,25,LA".encode('utf-8')
        result = self.parser.parse(content, 'data.csv')
        assert "Alice" in result
        assert "30" in result
        assert "NYC" in result
|
||||
|
||||
|
||||
class TestGetFileInfo:
    """Tests for the get_file_info helper."""

    def test_document_category(self):
        """PDF files are categorized as documents."""
        result = get_file_info('report.pdf', 1024)
        assert result['category'] == 'document'
        assert result['extension'] == '.pdf'

    def test_spreadsheet_category(self):
        """XLSX files are categorized as spreadsheets."""
        result = get_file_info('data.xlsx', 2048)
        assert result['category'] == 'spreadsheet'
        assert result['extension'] == '.xlsx'

    def test_presentation_category(self):
        """PPTX files are categorized as presentations."""
        result = get_file_info('slides.pptx', 4096)
        assert result['category'] == 'presentation'

    def test_code_category(self):
        """Python files are categorized as code."""
        result = get_file_info('script.py', 512)
        assert result['category'] == 'code'

    def test_image_category(self):
        """JPEG files are categorized as images."""
        result = get_file_info('photo.jpg', 8192)
        assert result['category'] == 'image'

    def test_file_size_bytes(self):
        """Sub-KB sizes are formatted in bytes."""
        result = get_file_info('small.txt', 512)
        assert result['size_formatted'] == '512 B'
        assert result['size_bytes'] == 512

    def test_file_size_kilobytes(self):
        """KB-range sizes are formatted with a KB suffix."""
        result = get_file_info('medium.txt', 2048)
        assert 'KB' in result['size_formatted']
        assert result['size_bytes'] == 2048

    def test_file_size_megabytes(self):
        """MB-range sizes are formatted with a MB suffix and one decimal."""
        result = get_file_info('large.pdf', 5 * 1024 * 1024)
        assert 'MB' in result['size_formatted']
        assert '5.0' in result['size_formatted']

    def test_file_size_gigabytes(self):
        """GB-range sizes are formatted with a GB suffix and one decimal."""
        result = get_file_info('huge.zip', 2 * 1024 * 1024 * 1024)
        assert 'GB' in result['size_formatted']
        assert '2.0' in result['size_formatted']

    def test_unknown_extension(self):
        """Unknown extensions fall back to the generic 'file' category."""
        result = get_file_info('file.xyz', 1024)
        assert result['category'] == 'file'
        assert result['extension'] == '.xyz'

    def test_no_extension(self):
        """A filename without an extension yields an empty extension."""
        result = get_file_info('README', 1024)
        assert result['extension'] == ''
|
||||
270
test_llm_client.py
Normal file
270
test_llm_client.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
Unit tests for llm_client.py
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from llm_client import (
|
||||
OllamaClient,
|
||||
OpenAIClient,
|
||||
AnthropicClient,
|
||||
create_llm_client
|
||||
)
|
||||
|
||||
|
||||
class TestOllamaClient:
    """Tests for OllamaClient."""

    def setup_method(self):
        """Build a client pointed at a local Ollama server."""
        self.client = OllamaClient(
            base_url="http://localhost:11434",
            model="llama3.2",
            timeout=30
        )

    def test_initialization(self):
        """Constructor stores base_url, model and timeout verbatim."""
        assert self.client.timeout == 30
        assert self.client.model == "llama3.2"
        assert self.client.base_url == "http://localhost:11434"

    def test_initialization_strips_trailing_slash(self):
        """A trailing slash on base_url is removed."""
        no_slash = OllamaClient(base_url="http://localhost:11434/")
        assert no_slash.base_url == "http://localhost:11434"

    @patch('llm_client.requests.post')
    def test_chat_without_context(self, mock_post):
        """chat returns the assistant message content from the API response."""
        fake_http = Mock()
        fake_http.json.return_value = {
            "message": {"content": "Hello! How can I help?"}
        }
        mock_post.return_value = fake_http

        reply = self.client.chat([{"role": "user", "content": "Hi"}])

        assert reply == "Hello! How can I help?"
        mock_post.assert_called_once()

    @patch('llm_client.requests.post')
    def test_chat_with_context(self, mock_post):
        """chat prepends a system message that carries the supplied context."""
        fake_http = Mock()
        fake_http.json.return_value = {
            "message": {"content": "Based on the document..."}
        }
        mock_post.return_value = fake_http

        reply = self.client.chat(
            [{"role": "user", "content": "What does it say?"}],
            context="This is a test document."
        )

        assert reply == "Based on the document..."

        # The context must have been injected as a leading system message.
        sent = mock_post.call_args[1]['json']['messages']
        assert sent[0]['role'] == 'system'
        assert 'test document' in sent[0]['content']

    @patch('llm_client.requests.post')
    def test_chat_handles_api_error(self, mock_post):
        """Transport failures surface as Ollama API errors."""
        import requests
        mock_post.side_effect = requests.exceptions.RequestException("Connection failed")

        with pytest.raises(Exception, match="Ollama API error"):
            self.client.chat([{"role": "user", "content": "Hi"}])

    @patch('llm_client.requests.get')
    def test_is_available_success(self, mock_get):
        """is_available is True when the server answers with 200."""
        ok = Mock()
        ok.status_code = 200
        mock_get.return_value = ok

        assert self.client.is_available() is True

    @patch('llm_client.requests.get')
    def test_is_available_failure(self, mock_get):
        """is_available is False when the server is unreachable."""
        mock_get.side_effect = Exception("Connection refused")

        assert self.client.is_available() is False
|
||||
|
||||
|
||||
class TestOpenAIClient:
    """Tests for OpenAIClient."""

    def setup_method(self):
        """Build a client with a dummy API key."""
        self.client = OpenAIClient(
            api_key="sk-test123",
            model="gpt-4",
            timeout=30
        )

    def test_initialization(self):
        """Constructor stores api_key, model and timeout verbatim."""
        assert self.client.timeout == 30
        assert self.client.model == "gpt-4"
        assert self.client.api_key == "sk-test123"

    @patch('llm_client.requests.post')
    def test_chat_without_context(self, mock_post):
        """chat returns the first choice's message content."""
        fake_http = Mock()
        fake_http.json.return_value = {
            "choices": [
                {"message": {"content": "Hello from GPT-4!"}}
            ]
        }
        mock_post.return_value = fake_http

        reply = self.client.chat([{"role": "user", "content": "Hi"}])

        assert reply == "Hello from GPT-4!"

    @patch('llm_client.requests.post')
    def test_chat_with_context(self, mock_post):
        """chat works when a document context is supplied."""
        fake_http = Mock()
        fake_http.json.return_value = {
            "choices": [
                {"message": {"content": "Based on the context..."}}
            ]
        }
        mock_post.return_value = fake_http

        reply = self.client.chat(
            [{"role": "user", "content": "Summarize"}],
            context="Test document content"
        )

        assert reply == "Based on the context..."

    @patch('llm_client.requests.get')
    def test_is_available_success(self, mock_get):
        """is_available is True when the API accepts the key."""
        ok = Mock()
        ok.status_code = 200
        mock_get.return_value = ok

        assert self.client.is_available() is True

    @patch('llm_client.requests.get')
    def test_is_available_failure(self, mock_get):
        """is_available is False when the API rejects the key."""
        mock_get.side_effect = Exception("Unauthorized")

        assert self.client.is_available() is False
|
||||
|
||||
|
||||
class TestAnthropicClient:
    """Tests for AnthropicClient."""

    def setup_method(self):
        """Build a client with a dummy API key."""
        self.client = AnthropicClient(
            api_key="sk-ant-test123",
            model="claude-3-5-sonnet-20241022",
            timeout=30
        )

    def test_initialization(self):
        """Constructor stores api_key, model and timeout verbatim."""
        assert self.client.timeout == 30
        assert self.client.model == "claude-3-5-sonnet-20241022"
        assert self.client.api_key == "sk-ant-test123"

    @patch('llm_client.requests.post')
    def test_chat_without_context(self, mock_post):
        """chat returns the text of the first content block."""
        fake_http = Mock()
        fake_http.json.return_value = {
            "content": [{"text": "Hello from Claude!"}]
        }
        mock_post.return_value = fake_http

        reply = self.client.chat([{"role": "user", "content": "Hi"}])

        assert reply == "Hello from Claude!"

    @patch('llm_client.requests.post')
    def test_chat_with_context(self, mock_post):
        """chat forwards the context via the top-level 'system' field."""
        fake_http = Mock()
        fake_http.json.return_value = {
            "content": [{"text": "Based on the document..."}]
        }
        mock_post.return_value = fake_http

        reply = self.client.chat(
            [{"role": "user", "content": "Summarize"}],
            context="Test document"
        )

        assert reply == "Based on the document..."

        # The Anthropic API carries the system prompt outside the messages list.
        sent_payload = mock_post.call_args[1]['json']
        assert 'system' in sent_payload
        assert 'Test document' in sent_payload['system']

    def test_is_available_with_api_key(self):
        """is_available is True whenever an API key is set."""
        assert self.client.is_available() is True

    def test_is_available_without_api_key(self):
        """is_available is False with an empty API key."""
        keyless = AnthropicClient(api_key="")
        assert keyless.is_available() is False
|
||||
|
||||
|
||||
class TestCreateLLMClient:
    """Tests for the create_llm_client factory function."""

    def test_create_ollama_client(self):
        """'ollama' yields an OllamaClient carrying the requested model."""
        built = create_llm_client("ollama", model="llama3.2")
        assert isinstance(built, OllamaClient)
        assert built.model == "llama3.2"

    def test_create_openai_client(self):
        """'openai' yields an OpenAIClient carrying the requested model."""
        built = create_llm_client("openai", api_key="sk-test", model="gpt-4")
        assert isinstance(built, OpenAIClient)
        assert built.model == "gpt-4"

    def test_create_openai_client_without_key(self):
        """'openai' without an API key raises ValueError."""
        with pytest.raises(ValueError, match="OpenAI API key required"):
            create_llm_client("openai")

    def test_create_anthropic_client(self):
        """'anthropic' yields an AnthropicClient."""
        built = create_llm_client("anthropic", api_key="sk-ant-test")
        assert isinstance(built, AnthropicClient)

    def test_create_anthropic_client_without_key(self):
        """'anthropic' without an API key raises ValueError."""
        with pytest.raises(ValueError, match="Anthropic API key required"):
            create_llm_client("anthropic")

    def test_create_unknown_provider(self):
        """An unrecognised provider name raises ValueError."""
        with pytest.raises(ValueError, match="Unknown LLM provider"):
            create_llm_client("unknown_provider")

    @patch.dict('os.environ', {'OLLAMA_URL': 'http://custom:11434', 'OLLAMA_MODEL': 'mistral'})
    def test_create_ollama_with_env_vars(self):
        """Ollama settings fall back to the OLLAMA_URL / OLLAMA_MODEL env vars."""
        built = create_llm_client("ollama")
        assert isinstance(built, OllamaClient)
        assert built.base_url == "http://custom:11434"
        assert built.model == "mistral"
|
||||
71
toothfairy_client.py
Normal file
71
toothfairy_client.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import os
|
||||
import requests
|
||||
|
||||
class ToothFairyClient:
    """Minimal client for the ToothFairyAI HTTP API.

    Credentials are read from the TOOTHFAIRYAI_API_KEY and
    TOOTHFAIRYAI_WORKSPACE_ID environment variables; the base URL defaults
    to the public API endpoint and can be overridden via TOOTHFAIRYAI_API_URL.
    """

    # requests has no default timeout; without one a stalled server hangs
    # the caller forever.
    REQUEST_TIMEOUT = 30

    def __init__(self):
        self.api_key = os.getenv('TOOTHFAIRYAI_API_KEY')
        self.workspace_id = os.getenv('TOOTHFAIRYAI_WORKSPACE_ID')
        self.base_url = os.getenv('TOOTHFAIRYAI_API_URL', 'https://api.toothfairyai.com')

    def _get_headers(self):
        """Return the auth and content-type headers required by every call."""
        return {
            "x-api-key": self.api_key,
            "Content-Type": "application/json"
        }

    @staticmethod
    def _extract_id(response_json):
        """Pull the resource id out of a create response.

        The API sometimes wraps the payload ({"data": {"id": ...}}) and
        sometimes returns it flat ({"id": ...}); accept both.
        """
        return response_json.get("data", {}).get("id") or response_json.get("id")

    def create_search_tool(self, ngrok_url):
        """Create the API Tool that tells the agent how to search your DB.

        Args:
            ngrok_url: Public base URL of this connector; the tool will POST
                to <ngrok_url>/api/search/chunks.

        Returns:
            The id of the created tool.

        Raises:
            requests.HTTPError: if the API rejects the request.
        """
        # Endpoint per the ToothFairyAI documentation.
        endpoint = f"{self.base_url}/function/create"
        payload = {
            "workspaceid": self.workspace_id,
            "name": "Search_SharePoint_Database",
            "url": f"{ngrok_url.rstrip('/')}/api/search/chunks",
            "requestType": "POST",
            "authorisationType": "none",
            "description": "Searches the company's internal SharePoint database for document excerpts. Use this whenever the user asks about internal files, reports, or policies.",
            "parameters": [
                {
                    "name": "query",
                    "type": "string",
                    "required": True,
                    "description": "The exact search query to find relevant information."
                }
            ]
        }

        response = requests.post(
            endpoint, json=payload, headers=self._get_headers(),
            timeout=self.REQUEST_TIMEOUT,
        )
        if not response.ok:
            # Surface the server's error body before raising (mirrors
            # create_agent) -- the status line alone is rarely actionable.
            print(f"❌ Server Error Output: {response.text}")
        response.raise_for_status()

        return self._extract_id(response.json())

    def create_agent(self, tool_id):
        """Create an Operator agent and attach the search tool to it.

        Args:
            tool_id: Id returned by create_search_tool.

        Returns:
            The id of the created agent.

        Raises:
            requests.HTTPError: if the API rejects the request.
        """
        endpoint = f"{self.base_url}/agent/create"
        payload = {
            "workspaceid": self.workspace_id,
            "label": "SharePoint Assistant",
            "mode": "retriever",
            "interpolationString": "You are a helpful corporate assistant. Use your Search_SharePoint_Database tool to find answers to user questions. Always cite the 'source_file' in your responses.",
            "goals": "Answer questions accurately using internal documentation.",
            "temperature": 0.3,
            "maxTokens": 2000,
            # maxHistory/topK/docTopK are required by the API.
            "maxHistory": 10,
            "topK": 10,
            "docTopK": 5,
            "hasFunctions": True,
            "agentFunctions": [tool_id],
            "agenticRAG": True
        }

        response = requests.post(
            endpoint, json=payload, headers=self._get_headers(),
            timeout=self.REQUEST_TIMEOUT,
        )
        if not response.ok:
            print(f"❌ Server Error Output: {response.text}")
        response.raise_for_status()

        return self._extract_id(response.json())
|
||||
477
vector_store.py
Normal file
477
vector_store.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""
|
||||
Vector Store for Multi-Document Chat with RAG (Retrieval-Augmented Generation)
|
||||
|
||||
Supports document embeddings, tagging, and semantic search for chatting with
|
||||
multiple documents at once.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
import hashlib
|
||||
|
||||
|
||||
@dataclass
class DocumentChunk:
    """A chunk of a document with metadata."""
    # Unique chunk identifier; elsewhere in this module it is built as
    # "<document_id>_chunk_<index>".
    chunk_id: str
    # Identifier of the parent document this chunk was cut from.
    document_id: str
    # Raw text content of the chunk.
    content: str
    # Zero-based position of the chunk within the parent document.
    chunk_index: int
    # Embedding vector for the chunk; None until embeddings are computed.
    embedding: Optional[List[float]] = None
    # Extra metadata propagated from the parent document's add call.
    metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
class Document:
    """Document metadata with tags."""
    # Deterministic identifier; derived elsewhere in this module from
    # (user_id, site_id, file_path).
    document_id: str
    # Owner of the document; used for access control on lookups.
    user_id: str
    # SharePoint site the file came from.
    site_id: str
    # Path of the file within the site.
    file_path: str
    # Display filename.
    filename: str
    # Free-form tags (e.g. ["HR", "SALES", "Q4-2024"]) used for filtering.
    tags: List[str]
    # SHA-256 hex digest of the content; used to skip re-indexing unchanged files.
    content_hash: str
    # ISO-8601 creation timestamp.
    created_at: str
    # ISO-8601 last-update timestamp.
    updated_at: str
    # Number of chunks the document was split into.
    chunk_count: int
    # Arbitrary extra metadata supplied by the caller.
    metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class EmbeddingProvider:
    """Abstract interface for text-embedding backends."""

    def embed_text(self, text: str) -> List[float]:
        """Return the embedding vector for a single piece of text."""
        raise NotImplementedError

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed many texts by delegating to embed_text one at a time."""
        return list(map(self.embed_text, texts))
|
||||
|
||||
|
||||
class OllamaEmbeddings(EmbeddingProvider):
    """Embeddings served by a local Ollama instance."""

    def __init__(self, base_url: str = "http://localhost:11434", model: str = "nomic-embed-text"):
        """
        Initialize Ollama embeddings.

        Args:
            base_url: Ollama server URL (a trailing slash is stripped)
            model: Embedding model (e.g., nomic-embed-text, mxbai-embed-large)
        """
        # Imported lazily so the module loads even where requests is absent.
        import requests
        self.requests = requests
        self.model = model
        self.base_url = base_url.rstrip('/')

    def embed_text(self, text: str) -> List[float]:
        """POST the text to Ollama's /api/embeddings and return the vector."""
        payload = {"model": self.model, "prompt": text}
        resp = self.requests.post(f"{self.base_url}/api/embeddings", json=payload)
        resp.raise_for_status()
        return resp.json()["embedding"]
|
||||
|
||||
|
||||
class OpenAIEmbeddings(EmbeddingProvider):
    """Embeddings served by the OpenAI embeddings API."""

    def __init__(self, api_key: str, model: str = "text-embedding-3-small"):
        """
        Initialize OpenAI embeddings.

        Args:
            api_key: OpenAI API key
            model: Embedding model (text-embedding-3-small or text-embedding-3-large)
        """
        # Imported lazily so the module loads even where requests is absent.
        import requests
        self.requests = requests
        self.model = model
        self.api_key = api_key

    def embed_text(self, text: str) -> List[float]:
        """POST the text to OpenAI's /v1/embeddings and return the vector."""
        auth_headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        resp = self.requests.post(
            "https://api.openai.com/v1/embeddings",
            headers=auth_headers,
            json={"input": text, "model": self.model}
        )
        resp.raise_for_status()
        return resp.json()["data"][0]["embedding"]
|
||||
|
||||
|
||||
class InMemoryVectorStore:
|
||||
"""
|
||||
Simple in-memory vector store for development.
|
||||
|
||||
For production, use a proper vector database like:
|
||||
- Pinecone
|
||||
- Weaviate
|
||||
- Qdrant
|
||||
- ChromaDB
|
||||
- pgvector (PostgreSQL extension)
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_provider: EmbeddingProvider):
|
||||
"""Initialize vector store."""
|
||||
self.embedding_provider = embedding_provider
|
||||
self.documents: Dict[str, Document] = {}
|
||||
self.chunks: Dict[str, DocumentChunk] = {}
|
||||
self.user_documents: Dict[str, List[str]] = {} # user_id -> [document_ids]
|
||||
self.tag_index: Dict[str, List[str]] = {} # tag -> [document_ids]
|
||||
|
||||
def add_document(
|
||||
self,
|
||||
user_id: str,
|
||||
site_id: str,
|
||||
file_path: str,
|
||||
filename: str,
|
||||
content: str,
|
||||
tags: List[str],
|
||||
chunk_size: int = 1000,
|
||||
chunk_overlap: int = 200,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Add a document to the vector store with chunking and embeddings.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
site_id: SharePoint site ID
|
||||
file_path: File path
|
||||
filename: Filename
|
||||
content: Document content
|
||||
tags: List of tags (e.g., ["HR", "SALES", "Q4-2024"])
|
||||
chunk_size: Size of each chunk in characters
|
||||
chunk_overlap: Overlap between chunks
|
||||
metadata: Additional metadata
|
||||
|
||||
Returns:
|
||||
document_id
|
||||
"""
|
||||
# Generate document ID
|
||||
document_id = self._generate_document_id(user_id, site_id, file_path)
|
||||
|
||||
# Calculate content hash
|
||||
content_hash = hashlib.sha256(content.encode()).hexdigest()
|
||||
|
||||
# Check if document already exists and hasn't changed
|
||||
if document_id in self.documents:
|
||||
existing_doc = self.documents[document_id]
|
||||
if existing_doc.content_hash == content_hash:
|
||||
# Update tags if different
|
||||
if set(tags) != set(existing_doc.tags):
|
||||
self._update_tags(document_id, existing_doc.tags, tags)
|
||||
existing_doc.tags = tags
|
||||
existing_doc.updated_at = datetime.utcnow().isoformat()
|
||||
return document_id
|
||||
else:
|
||||
# Content changed, remove old chunks
|
||||
self._remove_document_chunks(document_id)
|
||||
|
||||
# Chunk the document
|
||||
chunks = self._chunk_text(content, chunk_size, chunk_overlap)
|
||||
|
||||
# Generate embeddings for all chunks
|
||||
chunk_texts = [chunk for chunk in chunks]
|
||||
embeddings = self.embedding_provider.embed_batch(chunk_texts)
|
||||
|
||||
# Create document chunks
|
||||
document_chunks = []
|
||||
for idx, (chunk_text, embedding) in enumerate(zip(chunk_texts, embeddings)):
|
||||
chunk_id = f"{document_id}_chunk_{idx}"
|
||||
chunk = DocumentChunk(
|
||||
chunk_id=chunk_id,
|
||||
document_id=document_id,
|
||||
content=chunk_text,
|
||||
chunk_index=idx,
|
||||
embedding=embedding,
|
||||
metadata=metadata
|
||||
)
|
||||
self.chunks[chunk_id] = chunk
|
||||
document_chunks.append(chunk)
|
||||
|
||||
# Create document metadata
|
||||
now = datetime.utcnow().isoformat()
|
||||
document = Document(
|
||||
document_id=document_id,
|
||||
user_id=user_id,
|
||||
site_id=site_id,
|
||||
file_path=file_path,
|
||||
filename=filename,
|
||||
tags=tags,
|
||||
content_hash=content_hash,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
chunk_count=len(document_chunks),
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
# Store document
|
||||
self.documents[document_id] = document
|
||||
|
||||
# Update user index
|
||||
if user_id not in self.user_documents:
|
||||
self.user_documents[user_id] = []
|
||||
if document_id not in self.user_documents[user_id]:
|
||||
self.user_documents[user_id].append(document_id)
|
||||
|
||||
# Update tag index
|
||||
for tag in tags:
|
||||
tag_lower = tag.lower()
|
||||
if tag_lower not in self.tag_index:
|
||||
self.tag_index[tag_lower] = []
|
||||
if document_id not in self.tag_index[tag_lower]:
|
||||
self.tag_index[tag_lower].append(document_id)
|
||||
|
||||
return document_id
|
||||
|
||||
def search(
|
||||
self,
|
||||
user_id: str,
|
||||
query: str,
|
||||
tags: Optional[List[str]] = None,
|
||||
top_k: int = 5
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search for relevant document chunks.
|
||||
|
||||
Args:
|
||||
user_id: User ID (for access control)
|
||||
query: Search query
|
||||
tags: Optional list of tags to filter by
|
||||
top_k: Number of results to return
|
||||
|
||||
Returns:
|
||||
List of relevant chunks with similarity scores
|
||||
"""
|
||||
# Get query embedding
|
||||
query_embedding = self.embedding_provider.embed_text(query)
|
||||
|
||||
# Get candidate document IDs
|
||||
if tags:
|
||||
# Filter by tags
|
||||
candidate_doc_ids = set()
|
||||
for tag in tags:
|
||||
tag_lower = tag.lower()
|
||||
if tag_lower in self.tag_index:
|
||||
candidate_doc_ids.update(self.tag_index[tag_lower])
|
||||
# Intersect with user's documents
|
||||
user_doc_ids = set(self.user_documents.get(user_id, []))
|
||||
candidate_doc_ids = candidate_doc_ids.intersection(user_doc_ids)
|
||||
else:
|
||||
# All user's documents
|
||||
candidate_doc_ids = set(self.user_documents.get(user_id, []))
|
||||
|
||||
# Get chunks from candidate documents
|
||||
candidate_chunks = [
|
||||
chunk for chunk in self.chunks.values()
|
||||
if chunk.document_id in candidate_doc_ids
|
||||
]
|
||||
|
||||
# Calculate cosine similarity for each chunk
|
||||
results = []
|
||||
for chunk in candidate_chunks:
|
||||
if chunk.embedding:
|
||||
similarity = self._cosine_similarity(query_embedding, chunk.embedding)
|
||||
results.append({
|
||||
"chunk": chunk,
|
||||
"document": self.documents[chunk.document_id],
|
||||
"similarity": similarity
|
||||
})
|
||||
|
||||
# Sort by similarity and return top_k
|
||||
results.sort(key=lambda x: x["similarity"], reverse=True)
|
||||
return results[:top_k]
|
||||
|
||||
def get_documents_by_tags(self, user_id: str, tags: List[str]) -> List[Document]:
|
||||
"""Get all documents with specific tags."""
|
||||
doc_ids = set()
|
||||
for tag in tags:
|
||||
tag_lower = tag.lower()
|
||||
if tag_lower in self.tag_index:
|
||||
doc_ids.update(self.tag_index[tag_lower])
|
||||
|
||||
# Filter by user
|
||||
user_doc_ids = set(self.user_documents.get(user_id, []))
|
||||
doc_ids = doc_ids.intersection(user_doc_ids)
|
||||
|
||||
return [self.documents[doc_id] for doc_id in doc_ids if doc_id in self.documents]
|
||||
|
||||
def list_tags(self, user_id: str) -> Dict[str, int]:
|
||||
"""List all tags for user with document counts."""
|
||||
user_doc_ids = set(self.user_documents.get(user_id, []))
|
||||
tag_counts = {}
|
||||
|
||||
for tag, doc_ids in self.tag_index.items():
|
||||
user_tagged_docs = set(doc_ids).intersection(user_doc_ids)
|
||||
if user_tagged_docs:
|
||||
tag_counts[tag] = len(user_tagged_docs)
|
||||
|
||||
return tag_counts
|
||||
|
||||
def update_document_tags(self, document_id: str, user_id: str, tags: List[str]):
|
||||
"""Update tags for a document."""
|
||||
if document_id not in self.documents:
|
||||
raise ValueError("Document not found")
|
||||
|
||||
doc = self.documents[document_id]
|
||||
if doc.user_id != user_id:
|
||||
raise ValueError("Access denied")
|
||||
|
||||
old_tags = doc.tags
|
||||
self._update_tags(document_id, old_tags, tags)
|
||||
doc.tags = tags
|
||||
doc.updated_at = datetime.utcnow().isoformat()
|
||||
|
||||
def remove_document(self, document_id: str, user_id: str):
    """Remove a document plus its chunks and every index entry.

    Unknown document ids are ignored (idempotent delete).

    Raises:
        ValueError: If *user_id* does not own the document.
    """
    doc = self.documents.get(document_id)
    if doc is None:
        return
    if doc.user_id != user_id:
        raise ValueError("Access denied")

    # Drop the id from each tag bucket the document appears in.
    for tag in doc.tags:
        key = tag.lower()
        bucket = self.tag_index.get(key)
        if bucket is not None:
            self.tag_index[key] = [d for d in bucket if d != document_id]

    # Drop the document's cached chunks.
    self._remove_document_chunks(document_id)

    # Drop the id from the per-user index.
    if user_id in self.user_documents:
        self.user_documents[user_id] = [
            d for d in self.user_documents[user_id] if d != document_id
        ]

    # Finally remove the document record itself.
    self.documents.pop(document_id)
|
||||
|
||||
def _chunk_text(self, text: str, chunk_size: int, overlap: int) -> List[str]:
|
||||
"""Split text into overlapping chunks."""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
chunks = []
|
||||
start = 0
|
||||
text_length = len(text)
|
||||
|
||||
while start < text_length:
|
||||
end = start + chunk_size
|
||||
|
||||
# Try to break at sentence boundary
|
||||
if end < text_length:
|
||||
# Look for sentence endings
|
||||
for punct in ['. ', '! ', '? ', '\n\n']:
|
||||
last_punct = text.rfind(punct, start, end)
|
||||
if last_punct != -1:
|
||||
end = last_punct + len(punct)
|
||||
break
|
||||
|
||||
chunk = text[start:end].strip()
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
|
||||
start = end - overlap if end < text_length else text_length
|
||||
|
||||
return chunks
|
||||
|
||||
def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
|
||||
"""Calculate cosine similarity between two vectors."""
|
||||
import math
|
||||
|
||||
dot_product = sum(a * b for a, b in zip(vec1, vec2))
|
||||
magnitude1 = math.sqrt(sum(a * a for a in vec1))
|
||||
magnitude2 = math.sqrt(sum(b * b for b in vec2))
|
||||
|
||||
if magnitude1 == 0 or magnitude2 == 0:
|
||||
return 0.0
|
||||
|
||||
return dot_product / (magnitude1 * magnitude2)
|
||||
|
||||
def _generate_document_id(self, user_id: str, site_id: str, file_path: str) -> str:
|
||||
"""Generate unique document ID."""
|
||||
combined = f"{user_id}:{site_id}:{file_path}"
|
||||
return hashlib.sha256(combined.encode()).hexdigest()
|
||||
|
||||
def _remove_document_chunks(self, document_id: str):
|
||||
"""Remove all chunks for a document."""
|
||||
chunk_ids_to_remove = [
|
||||
chunk_id for chunk_id, chunk in self.chunks.items()
|
||||
if chunk.document_id == document_id
|
||||
]
|
||||
for chunk_id in chunk_ids_to_remove:
|
||||
del self.chunks[chunk_id]
|
||||
|
||||
def _update_tags(self, document_id: str, old_tags: List[str], new_tags: List[str]):
|
||||
"""Update tag index when tags change."""
|
||||
# Remove from old tags
|
||||
for tag in old_tags:
|
||||
tag_lower = tag.lower()
|
||||
if tag_lower in self.tag_index:
|
||||
self.tag_index[tag_lower] = [
|
||||
d for d in self.tag_index[tag_lower] if d != document_id
|
||||
]
|
||||
|
||||
# Add to new tags
|
||||
for tag in new_tags:
|
||||
tag_lower = tag.lower()
|
||||
if tag_lower not in self.tag_index:
|
||||
self.tag_index[tag_lower] = []
|
||||
if document_id not in self.tag_index[tag_lower]:
|
||||
self.tag_index[tag_lower].append(document_id)
|
||||
|
||||
|
||||
def create_embedding_provider(provider: str = "ollama", **kwargs) -> EmbeddingProvider:
    """
    Factory function to create embedding provider.

    Args:
        provider: Provider name ("ollama" or "openai"); case-insensitive.
        **kwargs: Provider-specific configuration (base_url/model/api_key).

    Returns:
        EmbeddingProvider instance

    Raises:
        ValueError: For an unknown provider, or OpenAI without an API key.

    Example:
        # Ollama (default)
        embeddings = create_embedding_provider("ollama", model="nomic-embed-text")

        # OpenAI
        embeddings = create_embedding_provider("openai", api_key="sk-...")
    """
    name = provider.lower()

    if name == "ollama":
        # Fall back to environment configuration, then built-in defaults.
        default_url = os.getenv("OLLAMA_URL", "http://localhost:11434")
        default_model = os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text")
        return OllamaEmbeddings(
            base_url=kwargs.get("base_url", default_url),
            model=kwargs.get("model", default_model),
        )

    if name == "openai":
        api_key = kwargs.get("api_key", os.getenv("OPENAI_API_KEY"))
        if not api_key:
            raise ValueError("OpenAI API key required")
        return OpenAIEmbeddings(
            api_key=api_key,
            model=kwargs.get("model", "text-embedding-3-small"),
        )

    raise ValueError(f"Unknown embedding provider: {provider}")
|
||||
652
vector_store_postgres.py
Normal file
652
vector_store_postgres.py
Normal file
@@ -0,0 +1,652 @@
|
||||
"""
|
||||
PostgreSQL Vector Store with pgvector for Multi-Document Chat with RAG
|
||||
|
||||
Persistent vector storage using PostgreSQL with pgvector extension.
|
||||
"""
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
import psycopg2
|
||||
from psycopg2.extras import RealDictCursor, execute_values
|
||||
from pgvector.psycopg2 import register_vector
|
||||
|
||||
|
||||
@dataclass
class DocumentChunk:
    """A chunk of a document with metadata."""
    # Unique chunk id (built as "<document_id>_chunk_<index>" on insert).
    chunk_id: str
    # Id of the owning document.
    document_id: str
    # Raw chunk text.
    content: str
    # Zero-based position of this chunk within the document.
    chunk_index: int
    # Embedding vector; None until computed.
    embedding: Optional[List[float]] = None
    # Free-form per-chunk metadata.
    metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
class Document:
    """Document metadata with tags."""
    # SHA-256 over "user:site:path" (see _generate_document_id).
    document_id: str
    # Owning user; used for access control on every query.
    user_id: str
    # SharePoint site the file came from.
    site_id: str
    # Path of the file within the site.
    file_path: str
    # Display filename.
    filename: str
    # Organizational tags (e.g. ["HR", "SALES"]).
    tags: List[str]
    # SHA-256 of the document content, used to skip unchanged re-indexing.
    content_hash: str
    # ISO-8601 timestamps (serialized via .isoformat()).
    created_at: str
    updated_at: str
    # Number of chunks the content was split into.
    chunk_count: int
    # Free-form per-document metadata.
    metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class EmbeddingProvider:
    """Base class for embedding providers."""

    def embed_text(self, text: str) -> List[float]:
        """Generate embedding for text.

        Subclasses must override; the base implementation raises.
        """
        raise NotImplementedError

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple texts.

        Default implementation embeds sequentially via embed_text();
        providers with a batch API may override for efficiency.
        """
        return [self.embed_text(text) for text in texts]
|
||||
|
||||
|
||||
class OllamaEmbeddings(EmbeddingProvider):
    """Ollama embeddings using local models."""

    def __init__(self, base_url: str = "http://localhost:11434", model: str = "nomic-embed-text"):
        """
        Initialize Ollama embeddings.

        Args:
            base_url: Ollama server URL
            model: Embedding model (e.g., nomic-embed-text, mxbai-embed-large)
        """
        import requests
        self.base_url = base_url.rstrip('/')
        self.model = model
        self.requests = requests

    def embed_text(self, text: str) -> List[float]:
        """Generate embedding using Ollama.

        Raises:
            requests.HTTPError: If the server returns an error status.
            requests.Timeout: If the server does not answer in time.
        """
        response = self.requests.post(
            f"{self.base_url}/api/embeddings",
            json={
                "model": self.model,
                "prompt": text
            },
            # Bound the wait: without a timeout a hung server blocks
            # indexing forever. Generous, since local models can be slow.
            timeout=120,
        )
        response.raise_for_status()
        return response.json()["embedding"]
|
||||
|
||||
|
||||
class OpenAIEmbeddings(EmbeddingProvider):
    """OpenAI embeddings."""

    def __init__(self, api_key: str, model: str = "text-embedding-3-small"):
        """
        Initialize OpenAI embeddings.

        Args:
            api_key: OpenAI API key
            model: Embedding model (text-embedding-3-small or text-embedding-3-large)
        """
        import requests
        self.api_key = api_key
        self.model = model
        self.requests = requests

    def embed_text(self, text: str) -> List[float]:
        """Generate embedding using OpenAI.

        Raises:
            requests.HTTPError: If the API returns an error status.
            requests.Timeout: If the API does not answer in time.
        """
        response = self.requests.post(
            "https://api.openai.com/v1/embeddings",
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            },
            json={
                "input": text,
                "model": self.model
            },
            # Bound the wait: without a timeout a stalled connection
            # blocks indexing forever.
            timeout=60,
        )
        response.raise_for_status()
        return response.json()["data"][0]["embedding"]
|
||||
|
||||
|
||||
class PostgreSQLVectorStore:
|
||||
"""
|
||||
PostgreSQL vector store with pgvector extension.
|
||||
|
||||
Provides persistent storage for document embeddings with tag-based organization.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    embedding_provider: EmbeddingProvider,
    connection_string: Optional[str] = None,
    table_prefix: str = ""
):
    """
    Initialize PostgreSQL vector store.

    Opens the connection, creates the schema (enabling the pgvector
    extension) and registers the pgvector type adapter — in that order,
    because register_vector needs the extension to exist.

    Args:
        embedding_provider: Provider for generating embeddings
        connection_string: PostgreSQL connection string (uses env vars if not provided)
        table_prefix: Prefix for table names (e.g., "dev_" or "prod_")
    """
    self.embedding_provider = embedding_provider
    self.table_prefix = table_prefix

    # Use connection string or build from env vars.
    # NOTE(review): fallback defaults include password 'postgres' —
    # fine for local dev; confirm deployments always set the env vars.
    if connection_string:
        self.connection_string = connection_string
    else:
        self.connection_string = (
            f"host={os.getenv('POSTGRES_HOST', 'localhost')} "
            f"port={os.getenv('POSTGRES_PORT', '5432')} "
            f"dbname={os.getenv('POSTGRES_DB', 'sharepoint_vectors')} "
            f"user={os.getenv('POSTGRES_USER', 'postgres')} "
            f"password={os.getenv('POSTGRES_PASSWORD', 'postgres')}"
        )

    # Initialize database connection; explicit transactions — every
    # write path commits (or should roll back) itself.
    self.conn = psycopg2.connect(self.connection_string)
    self.conn.autocommit = False

    # Create tables and enable extension if they don't exist
    self._create_tables()

    # Register pgvector type (must be after extension is enabled)
    register_vector(self.conn)
|
||||
|
||||
def _create_tables(self):
    """Create tables for documents, chunks, and tags.

    Idempotent (IF NOT EXISTS everywhere); commits the DDL at the end.
    NOTE(review): ``table_prefix`` is interpolated directly into SQL
    identifiers — it must come from trusted configuration, never from
    user input.
    """
    with self.conn.cursor() as cur:
        # Enable pgvector extension
        cur.execute("CREATE EXTENSION IF NOT EXISTS vector")

        # Documents table — one row per (user, site, path).
        cur.execute(f"""
            CREATE TABLE IF NOT EXISTS {self.table_prefix}documents (
                document_id VARCHAR(64) PRIMARY KEY,
                user_id VARCHAR(255) NOT NULL,
                site_id VARCHAR(255) NOT NULL,
                file_path TEXT NOT NULL,
                filename VARCHAR(255) NOT NULL,
                content_hash VARCHAR(64) NOT NULL,
                created_at TIMESTAMP NOT NULL,
                updated_at TIMESTAMP NOT NULL,
                chunk_count INTEGER NOT NULL,
                metadata JSONB,
                UNIQUE(user_id, site_id, file_path)
            )
        """)

        # Document chunks table with vector embeddings.
        # Assuming 768 dimensions for nomic-embed-text (adjust based on
        # your model) — a different embedding model needs a schema change.
        cur.execute(f"""
            CREATE TABLE IF NOT EXISTS {self.table_prefix}document_chunks (
                chunk_id VARCHAR(128) PRIMARY KEY,
                document_id VARCHAR(64) NOT NULL REFERENCES {self.table_prefix}documents(document_id) ON DELETE CASCADE,
                content TEXT NOT NULL,
                chunk_index INTEGER NOT NULL,
                embedding vector(768),
                metadata JSONB,
                UNIQUE(document_id, chunk_index)
            )
        """)

        # Tags table (many-to-many with documents); rows are stored
        # lowercased by _update_tags.
        cur.execute(f"""
            CREATE TABLE IF NOT EXISTS {self.table_prefix}document_tags (
                document_id VARCHAR(64) NOT NULL REFERENCES {self.table_prefix}documents(document_id) ON DELETE CASCADE,
                tag VARCHAR(100) NOT NULL,
                PRIMARY KEY(document_id, tag)
            )
        """)

        # Create indexes for performance
        cur.execute(f"""
            CREATE INDEX IF NOT EXISTS idx_documents_user_id
            ON {self.table_prefix}documents(user_id)
        """)

        cur.execute(f"""
            CREATE INDEX IF NOT EXISTS idx_document_tags_tag
            ON {self.table_prefix}document_tags(tag)
        """)

        # Create IVFFlat index for vector similarity search (after some data is inserted)
        # This is commented out - create it manually after inserting ~1000+ vectors:
        # CREATE INDEX ON document_chunks USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);

        self.conn.commit()
|
||||
|
||||
def add_document(
    self,
    user_id: str,
    site_id: str,
    file_path: str,
    filename: str,
    content: str,
    tags: List[str],
    chunk_size: int = 1000,
    chunk_overlap: int = 200,
    metadata: Optional[Dict[str, Any]] = None
) -> str:
    """
    Add a document to the vector store with chunking and embeddings.

    If the document already exists with an identical content hash, only
    the tags/timestamp are refreshed; if the content changed, its chunks
    are re-generated and re-embedded.

    Args:
        user_id: User ID
        site_id: SharePoint site ID
        file_path: File path
        filename: Filename
        content: Document content
        tags: List of tags (e.g., ["HR", "SALES", "Q4-2024"])
        chunk_size: Size of each chunk in characters
        chunk_overlap: Overlap between chunks
        metadata: Additional metadata

    Returns:
        document_id
    """
    # Generate document ID (deterministic over user/site/path).
    document_id = self._generate_document_id(user_id, site_id, file_path)

    # Calculate content hash to detect unchanged re-indexing.
    content_hash = hashlib.sha256(content.encode()).hexdigest()

    try:
        with self.conn.cursor() as cur:
            # Check if document already exists
            cur.execute(
                f"SELECT content_hash FROM {self.table_prefix}documents WHERE document_id = %s",
                (document_id,)
            )
            existing = cur.fetchone()

            if existing and existing[0] == content_hash:
                # Content hasn't changed, just update tags if different
                cur.execute(
                    f"SELECT tag FROM {self.table_prefix}document_tags WHERE document_id = %s",
                    (document_id,)
                )
                existing_tags = [row[0] for row in cur.fetchall()]

                if set(tags) != set(existing_tags):
                    self._update_tags(cur, document_id, tags)
                    cur.execute(
                        f"UPDATE {self.table_prefix}documents SET updated_at = %s WHERE document_id = %s",
                        (datetime.utcnow(), document_id)
                    )
                    self.conn.commit()

                return document_id

            # Delete old chunks if content changed
            if existing:
                cur.execute(
                    f"DELETE FROM {self.table_prefix}document_chunks WHERE document_id = %s",
                    (document_id,)
                )

            # Chunk the document
            chunks = self._chunk_text(content, chunk_size, chunk_overlap)

            # Generate embeddings for all chunks
            embeddings = self.embedding_provider.embed_batch(chunks)

            # Insert or update document
            now = datetime.utcnow()
            if existing:
                cur.execute(f"""
                    UPDATE {self.table_prefix}documents
                    SET content_hash = %s, updated_at = %s, chunk_count = %s, metadata = %s
                    WHERE document_id = %s
                """, (content_hash, now, len(chunks), psycopg2.extras.Json(metadata), document_id))
            else:
                cur.execute(f"""
                    INSERT INTO {self.table_prefix}documents
                    (document_id, user_id, site_id, file_path, filename, content_hash, created_at, updated_at, chunk_count, metadata)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                """, (document_id, user_id, site_id, file_path, filename, content_hash, now, now, len(chunks), psycopg2.extras.Json(metadata)))

            # Insert chunks with embeddings
            chunk_data = [
                (
                    f"{document_id}_chunk_{idx}",
                    document_id,
                    chunk_text,
                    idx,
                    embedding,
                    psycopg2.extras.Json(metadata)
                )
                for idx, (chunk_text, embedding) in enumerate(zip(chunks, embeddings))
            ]

            # Empty content produces no chunks; skip the bulk insert then.
            if chunk_data:
                execute_values(
                    cur,
                    f"""
                    INSERT INTO {self.table_prefix}document_chunks
                    (chunk_id, document_id, content, chunk_index, embedding, metadata)
                    VALUES %s
                    """,
                    chunk_data
                )

            # Update tags
            self._update_tags(cur, document_id, tags)

            self.conn.commit()
    except Exception:
        # autocommit is off: without a rollback a failure here would
        # leave the connection stuck in an aborted transaction and
        # break every subsequent call on this store.
        self.conn.rollback()
        raise

    return document_id
|
||||
|
||||
def search(
    self,
    user_id: str,
    query: str,
    tags: Optional[List[str]] = None,
    top_k: int = 5
) -> List[Dict[str, Any]]:
    """
    Search for relevant document chunks using vector similarity.

    Args:
        user_id: User ID (for access control)
        query: Search query
        tags: Optional list of tags to filter by (case-insensitive;
            a document matches if it carries ANY of them)
        top_k: Number of results to return

    Returns:
        List of dicts with "chunk" (DocumentChunk), "document" (Document)
        and "similarity" (float), ordered by descending similarity.
    """
    # Generate query embedding once; similarity itself is computed
    # in-database by pgvector (<=> is cosine distance; 1 - distance
    # gives cosine similarity).
    query_embedding = self.embedding_provider.embed_text(query)

    with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
        # Build query based on whether tags are specified
        if tags:
            # Filter by tags using JOIN
            # NOTE(review): with this inner JOIN, ARRAY_AGG(dt.tag) only
            # aggregates the tags that matched the filter, not the
            # document's full tag set — confirm that is intended.
            tag_placeholders = ','.join(['%s'] * len(tags))
            cur.execute(f"""
                SELECT
                    c.chunk_id, c.document_id, c.content, c.chunk_index, c.metadata as chunk_metadata,
                    d.user_id, d.site_id, d.file_path, d.filename, d.created_at, d.updated_at,
                    d.chunk_count, d.content_hash, d.metadata as doc_metadata,
                    ARRAY_AGG(dt.tag) as tags,
                    1 - (c.embedding <=> %s::vector) as similarity
                FROM {self.table_prefix}document_chunks c
                JOIN {self.table_prefix}documents d ON c.document_id = d.document_id
                JOIN {self.table_prefix}document_tags dt ON d.document_id = dt.document_id
                WHERE d.user_id = %s AND LOWER(dt.tag) IN ({tag_placeholders})
                GROUP BY c.chunk_id, c.document_id, c.content, c.chunk_index, c.metadata,
                         d.user_id, d.site_id, d.file_path, d.filename, d.created_at,
                         d.updated_at, d.chunk_count, d.content_hash, d.metadata, c.embedding
                ORDER BY similarity DESC
                LIMIT %s
            """, [query_embedding, user_id] + [tag.lower() for tag in tags] + [top_k])
        else:
            # All user's documents (LEFT JOIN so untagged documents are
            # included with tags = [NULL] -> normalized to [] below).
            cur.execute(f"""
                SELECT
                    c.chunk_id, c.document_id, c.content, c.chunk_index, c.metadata as chunk_metadata,
                    d.user_id, d.site_id, d.file_path, d.filename, d.created_at, d.updated_at,
                    d.chunk_count, d.content_hash, d.metadata as doc_metadata,
                    ARRAY_AGG(dt.tag) as tags,
                    1 - (c.embedding <=> %s::vector) as similarity
                FROM {self.table_prefix}document_chunks c
                JOIN {self.table_prefix}documents d ON c.document_id = d.document_id
                LEFT JOIN {self.table_prefix}document_tags dt ON d.document_id = dt.document_id
                WHERE d.user_id = %s
                GROUP BY c.chunk_id, c.document_id, c.content, c.chunk_index, c.metadata,
                         d.user_id, d.site_id, d.file_path, d.filename, d.created_at,
                         d.updated_at, d.chunk_count, d.content_hash, d.metadata, c.embedding
                ORDER BY similarity DESC
                LIMIT %s
            """, [query_embedding, user_id, top_k])

        rows = cur.fetchall()

    # Convert to result format: materialize dataclasses from the rows.
    results = []
    for row in rows:
        chunk = DocumentChunk(
            chunk_id=row['chunk_id'],
            document_id=row['document_id'],
            content=row['content'],
            chunk_index=row['chunk_index'],
            metadata=row['chunk_metadata']
        )

        document = Document(
            document_id=row['document_id'],
            user_id=row['user_id'],
            site_id=row['site_id'],
            file_path=row['file_path'],
            filename=row['filename'],
            tags=row['tags'] or [],
            content_hash=row['content_hash'],
            created_at=row['created_at'].isoformat(),
            updated_at=row['updated_at'].isoformat(),
            chunk_count=row['chunk_count'],
            metadata=row['doc_metadata']
        )

        results.append({
            "chunk": chunk,
            "document": document,
            "similarity": float(row['similarity'])
        })

    return results
|
||||
|
||||
def get_documents_by_tags(self, user_id: str, tags: List[str]) -> List[Document]:
    """Get all documents owned by *user_id* carrying any of *tags*.

    Tag matching is case-insensitive.

    Args:
        user_id: User ID (for access control)
        tags: Tags to match; an empty list yields no matches.

    Returns:
        Matching Document objects.
    """
    # Guard: an empty list would render "IN ()", which is invalid SQL.
    if not tags:
        return []

    with self.conn.cursor(cursor_factory=RealDictCursor) as cur:
        tag_placeholders = ','.join(['%s'] * len(tags))
        cur.execute(f"""
            SELECT DISTINCT
                d.document_id, d.user_id, d.site_id, d.file_path, d.filename,
                d.content_hash, d.created_at, d.updated_at, d.chunk_count, d.metadata,
                ARRAY_AGG(dt.tag) as tags
            FROM {self.table_prefix}documents d
            JOIN {self.table_prefix}document_tags dt ON d.document_id = dt.document_id
            WHERE d.user_id = %s AND LOWER(dt.tag) IN ({tag_placeholders})
            GROUP BY d.document_id, d.user_id, d.site_id, d.file_path, d.filename,
                     d.content_hash, d.created_at, d.updated_at, d.chunk_count, d.metadata
        """, [user_id] + [tag.lower() for tag in tags])

        rows = cur.fetchall()

    return [
        Document(
            document_id=row['document_id'],
            user_id=row['user_id'],
            site_id=row['site_id'],
            file_path=row['file_path'],
            filename=row['filename'],
            tags=row['tags'] or [],
            content_hash=row['content_hash'],
            created_at=row['created_at'].isoformat(),
            updated_at=row['updated_at'].isoformat(),
            chunk_count=row['chunk_count'],
            metadata=row['metadata']
        )
        for row in rows
    ]
|
||||
|
||||
def list_tags(self, user_id: str) -> Dict[str, int]:
    """Return tag -> distinct-document count for the user's documents,
    ordered by count (descending) then tag name."""
    stats_sql = f"""
        SELECT dt.tag, COUNT(DISTINCT dt.document_id) as count
        FROM {self.table_prefix}document_tags dt
        JOIN {self.table_prefix}documents d ON dt.document_id = d.document_id
        WHERE d.user_id = %s
        GROUP BY dt.tag
        ORDER BY count DESC, dt.tag
    """
    with self.conn.cursor() as cur:
        cur.execute(stats_sql, (user_id,))
        counts = {}
        for tag, doc_count in cur.fetchall():
            counts[tag] = doc_count
    return counts
|
||||
|
||||
def get_indexed_sites(self, user_id: str) -> List[str]:
    """
    Get list of site IDs that have indexed documents for this user.

    Args:
        user_id: User ID

    Returns:
        Sorted list of distinct site IDs with indexed documents.
    """
    sites_sql = f"""
        SELECT DISTINCT site_id
        FROM {self.table_prefix}documents
        WHERE user_id = %s
        ORDER BY site_id
    """
    with self.conn.cursor() as cur:
        cur.execute(sites_sql, (user_id,))
        return [record[0] for record in cur.fetchall()]
|
||||
|
||||
def update_document_tags(self, document_id: str, user_id: str, tags: List[str]):
    """Replace a document's tags after verifying the caller owns it.

    Raises:
        ValueError: If the document does not exist or belongs to another user.
    """
    with self.conn.cursor() as cur:
        # Ownership check: distinguish "missing" from "not yours".
        cur.execute(
            f"SELECT user_id FROM {self.table_prefix}documents WHERE document_id = %s",
            (document_id,)
        )
        owner_row = cur.fetchone()
        if owner_row is None:
            raise ValueError("Document not found")
        if owner_row[0] != user_id:
            raise ValueError("Access denied")

        # Swap the tag set, then bump the modification timestamp.
        self._update_tags(cur, document_id, tags)
        cur.execute(
            f"UPDATE {self.table_prefix}documents SET updated_at = %s WHERE document_id = %s",
            (datetime.utcnow(), document_id)
        )

        self.conn.commit()
|
||||
|
||||
def remove_document(self, document_id: str, user_id: str):
    """Delete a document; its chunks and tags go with it via ON DELETE CASCADE.

    Unknown document ids are ignored (idempotent delete).

    Raises:
        ValueError: If *user_id* does not own the document.
    """
    with self.conn.cursor() as cur:
        cur.execute(
            f"SELECT user_id FROM {self.table_prefix}documents WHERE document_id = %s",
            (document_id,)
        )
        owner_row = cur.fetchone()

        if owner_row is None:
            return
        if owner_row[0] != user_id:
            raise ValueError("Access denied")

        # Single delete suffices: the FK constraints cascade to the
        # chunk and tag tables.
        cur.execute(
            f"DELETE FROM {self.table_prefix}documents WHERE document_id = %s",
            (document_id,)
        )

        self.conn.commit()
|
||||
|
||||
def _update_tags(self, cur, document_id: str, tags: List[str]):
    """Update tags for a document within a transaction.

    Replaces the document's full tag set: deletes existing rows, then
    bulk-inserts the new tags (lowercased, so lookups are
    case-insensitive). Caller is responsible for commit/rollback.
    """
    # Delete old tags
    cur.execute(
        f"DELETE FROM {self.table_prefix}document_tags WHERE document_id = %s",
        (document_id,)
    )

    # Insert new tags. Dedupe after lowercasing: ["HR", "hr"] would
    # otherwise insert the same row twice and violate the
    # PRIMARY KEY(document_id, tag) constraint. dict.fromkeys keeps
    # first-seen order.
    if tags:
        unique_tags = dict.fromkeys(tag.lower() for tag in tags)
        tag_data = [(document_id, tag) for tag in unique_tags]
        execute_values(
            cur,
            f"INSERT INTO {self.table_prefix}document_tags (document_id, tag) VALUES %s",
            tag_data
        )
|
||||
|
||||
def _chunk_text(self, text: str, chunk_size: int, overlap: int) -> List[str]:
|
||||
"""Split text into overlapping chunks."""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
chunks = []
|
||||
start = 0
|
||||
text_length = len(text)
|
||||
|
||||
while start < text_length:
|
||||
end = start + chunk_size
|
||||
|
||||
# Try to break at sentence boundary
|
||||
if end < text_length:
|
||||
for punct in ['. ', '! ', '? ', '\n\n']:
|
||||
last_punct = text.rfind(punct, start, end)
|
||||
if last_punct != -1:
|
||||
end = last_punct + len(punct)
|
||||
break
|
||||
|
||||
chunk = text[start:end].strip()
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
|
||||
start = end - overlap if end < text_length else text_length
|
||||
|
||||
return chunks
|
||||
|
||||
def _generate_document_id(self, user_id: str, site_id: str, file_path: str) -> str:
|
||||
"""Generate unique document ID."""
|
||||
combined = f"{user_id}:{site_id}:{file_path}"
|
||||
return hashlib.sha256(combined.encode()).hexdigest()
|
||||
|
||||
def close(self):
    """Close database connection."""
    # Guard against a half-constructed instance / double close.
    if self.conn:
        self.conn.close()

def __enter__(self):
    # Context-manager support: "with PostgreSQLVectorStore(...) as store:".
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    # Always release the connection on context exit, even on error.
    self.close()
|
||||
|
||||
|
||||
def create_embedding_provider(provider: str = "ollama", **kwargs) -> EmbeddingProvider:
    """
    Factory function to create embedding provider.

    Args:
        provider: Provider name ("ollama" or "openai"); case-insensitive.
        **kwargs: Provider-specific configuration (base_url/model/api_key).

    Returns:
        EmbeddingProvider instance

    Raises:
        ValueError: For an unknown provider, or OpenAI without an API key.
    """
    name = provider.lower()

    if name == "ollama":
        # Fall back to environment configuration, then built-in defaults.
        default_url = os.getenv("OLLAMA_URL", "http://localhost:11434")
        default_model = os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text")
        return OllamaEmbeddings(
            base_url=kwargs.get("base_url", default_url),
            model=kwargs.get("model", default_model),
        )

    if name == "openai":
        api_key = kwargs.get("api_key", os.getenv("OPENAI_API_KEY"))
        if not api_key:
            raise ValueError("OpenAI API key required")
        return OpenAIEmbeddings(
            api_key=api_key,
            model=kwargs.get("model", "text-embedding-3-small"),
        )

    raise ValueError(f"Unknown embedding provider: {provider}")
|
||||
Reference in New Issue
Block a user