from fastapi import Request
import time
import json
import os
import asyncio
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional
import logging

from utils.db import get_connection

logger = logging.getLogger(__name__)

# Set to True by _ensure_api_logs_table_exists() once the api_logs table is known
# to exist; start_logging_system() will not start the queue/worker while it is False.
_API_LOGS_TABLE_READY = False

# Configuration for logging optimization, read once from the environment at import
# time. A non-numeric value for any of the numeric settings raises ValueError here.
# Only the string "true" (any case) enables logging; any other value disables it.
ENABLE_LOGGING = os.getenv("ENABLE_API_LOGGING", "true").lower() == "true"
# Fraction of eligible requests to log; sampling is skipped entirely when >= 1.0.
LOGGING_SAMPLE_RATE = float(os.getenv("LOGGING_SAMPLE_RATE", "1.0"))  # 1.0 = log all, 0.1 = log 10%
# Rows whose timestamp is older than this many days are deleted by the cleanup task.
LOG_RETENTION_DAYS = int(os.getenv("LOG_RETENTION_DAYS", "30"))
# The worker flushes once this many entries are buffered...
LOG_BATCH_SIZE = int(os.getenv("LOG_BATCH_SIZE", "10"))
# ...or once this many seconds have passed since the last flush; also used as the
# queue-poll timeout of the background worker.
LOG_BATCH_INTERVAL = float(os.getenv("LOG_BATCH_INTERVAL", "2.0"))

# Static file extensions to skip logging (matched case-sensitively against the
# end of the request path)
STATIC_EXTENSIONS = {'.js', '.css', '.svg', '.png', '.jpg', '.jpeg', '.gif', '.webp', '.ico', '.woff', '.woff2', '.ttf', '.eot'}

# Paths to skip logging (exact match against the request path)
SKIP_PATHS = {'/health', '/favicon.ico'}

# Background logging queue and worker. All three are created by
# start_logging_system(); while _log_queue is None the middleware drops entries.
_log_queue: Optional[asyncio.Queue] = None
_log_worker_task: Optional[asyncio.Task] = None
_shutdown_event: Optional[asyncio.Event] = None


@dataclass
class LogEntry:
    """Data class for API log entries.

    One captured request/response pair, queued by the middleware and written
    to ``api_logs`` by the background worker. The field order mirrors the
    column order of the INSERT in ``_insert_log_batch``.
    """
    timestamp: datetime  # naive local time (the middleware uses datetime.now())
    client_ip: str  # request.client.host, or "unknown"
    method: str
    url: str  # full request URL including query string
    path: str
    query_params: str  # JSON-encoded object; "{}" when there are no params
    request_body: Optional[str]  # JSON-encoded body (POST/PUT/PATCH only), else None
    user_agent: str  # header values below default to "unknown" when absent
    accept_language: str
    accept_encoding: str
    referer: str
    x_forwarded_for: str
    # NOTE(review): annotated int, but the middleware reads it with
    # getattr(response, "status_code", None) — confirm it can never be None.
    status_code: int
    process_time: float  # seconds, measured with time.perf_counter()
    response_size: int  # bytes; 0 when the size could not be determined


async def _ensure_api_logs_table_exists():
    """Make sure the ``api_logs`` table exists, creating it if necessary.

    After one successful run the module flag ``_API_LOGS_TABLE_READY`` is set
    and later calls return immediately. Any failure (e.g. no database
    connection) is logged as a warning and swallowed, leaving the flag unset so
    that ``start_logging_system`` can decline to start.
    """
    global _API_LOGS_TABLE_READY
    if _API_LOGS_TABLE_READY:
        return
    try:
        conn = await get_connection()
        # Check if table exists first, so a normal startup doesn't emit the
        # "table already exists" warning that IF NOT EXISTS produces.
        await conn.execute("SHOW TABLES LIKE 'api_logs'")
        table_exists = await conn.fetchone()

        if not table_exists:
            # IF NOT EXISTS closes the check-then-create race: when several
            # workers start at once, more than one can see the table missing.
            # Without it, the losers' CREATE fails, the flag stays unset and
            # their logging system never starts.
            await conn.execute(
                """
                CREATE TABLE IF NOT EXISTS api_logs (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    timestamp DATETIME NOT NULL,
                    client_ip VARCHAR(45) NOT NULL,
                    method VARCHAR(10) NOT NULL,
                    url TEXT NOT NULL,
                    path VARCHAR(255) NOT NULL,
                    query_params JSON,
                    request_body JSON,
                    user_agent TEXT,
                    accept_language VARCHAR(255),
                    accept_encoding VARCHAR(255),
                    referer TEXT,
                    x_forwarded_for VARCHAR(255),
                    status_code INT NOT NULL,
                    process_time FLOAT NOT NULL,
                    response_size INT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    INDEX idx_timestamp (timestamp),
                    INDEX idx_client_ip (client_ip),
                    INDEX idx_method (method),
                    INDEX idx_path (path),
                    INDEX idx_status_code (status_code)
                )
                """
            )
            await conn.commit()
            logger.info("Created api_logs table with indexes")
        _API_LOGS_TABLE_READY = True
    except Exception as e:
        # Best effort: log and leave the flag unset rather than crash startup.
        logger.warning(f"Could not ensure api_logs table exists: {e}")
        logger.warning("API logging will be disabled until database connection is restored")


async def _insert_log_batch(log_entries: list[LogEntry]):
    """Write *log_entries* to ``api_logs`` in one ``executemany`` call, then commit.

    An empty list is a no-op. This is best effort: any database error is
    logged and swallowed, so the entries of a failed batch are dropped.
    """
    if not log_entries:
        return

    try:
        conn = await get_connection()

        insert_sql = """
            INSERT INTO api_logs
            (timestamp, client_ip, method, url, path, query_params, request_body, 
             user_agent, accept_language, accept_encoding, referer, x_forwarded_for, 
             status_code, process_time, response_size)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            """
        # One parameter tuple per entry, in the same order as the column list.
        rows = [
            (
                item.timestamp,
                item.client_ip,
                item.method,
                item.url,
                item.path,
                item.query_params,
                item.request_body,
                item.user_agent,
                item.accept_language,
                item.accept_encoding,
                item.referer,
                item.x_forwarded_for,
                item.status_code,
                item.process_time,
                item.response_size,
            )
            for item in log_entries
        ]

        await conn.executemany(insert_sql, rows)
        await conn.commit()
        logger.debug(f"Successfully inserted {len(log_entries)} log entries")
    except Exception as e:
        logger.error(f"Error inserting log batch: {e}")


async def _background_log_worker():
    """Drain the log queue and write entries to the database in batches.

    A batch is flushed when it holds LOG_BATCH_SIZE entries, when
    LOG_BATCH_INTERVAL seconds have passed since the previous flush, or once
    shutdown has been signalled. After ``_shutdown_event`` is set, the worker
    pulls everything already queued and writes it in LOG_BATCH_SIZE chunks
    instead of one INSERT/COMMIT per entry. It exits once the queue is empty.
    Errors are logged and the loop keeps running.
    """
    logger.info("Started background log worker")
    batch: list[LogEntry] = []
    # Monotonic clock: flush timing must not jump when the wall clock is adjusted.
    last_flush_time = time.monotonic()
    chunk_size = max(LOG_BATCH_SIZE, 1)

    while True:
        try:
            # Wait for the next entry. The timeout lets time-based flushes and
            # shutdown checks happen even when no traffic arrives.
            try:
                log_entry = await asyncio.wait_for(_log_queue.get(), timeout=LOG_BATCH_INTERVAL)
            except asyncio.TimeoutError:
                pass
            else:
                batch.append(log_entry)

            shutting_down = _shutdown_event.is_set()
            if shutting_down:
                # Take everything already queued (no await in between, so the
                # empty() check cannot go stale) to flush it in full batches.
                while not _log_queue.empty():
                    batch.append(_log_queue.get_nowait())

            now = time.monotonic()
            should_flush = (
                shutting_down
                or len(batch) >= LOG_BATCH_SIZE
                or now - last_flush_time >= LOG_BATCH_INTERVAL
            )
            if batch and should_flush:
                for start in range(0, len(batch), chunk_size):
                    await _insert_log_batch(batch[start:start + chunk_size])
                batch = []
                last_flush_time = now

            # Entries may arrive while the flush above awaits; keep looping
            # until a shutdown pass finds the queue empty.
            if shutting_down and _log_queue.empty():
                logger.info("Background log worker shutting down")
                break

        except Exception as e:
            logger.error(f"Error in background log worker: {e}")
            # Continue processing even if one iteration fails


async def _cleanup_old_logs():
    """Delete ``api_logs`` rows whose ``timestamp`` is older than LOG_RETENTION_DAYS.

    The deleted-row count is read from the connection's last cursor (0 when
    there is none) and reported only when positive. Any error is logged and
    swallowed.
    """
    try:
        conn = await get_connection()
        now = datetime.now()
        cutoff_date = now - timedelta(days=LOG_RETENTION_DAYS)

        await conn.execute("DELETE FROM api_logs WHERE timestamp < %s", (cutoff_date,))
        # NOTE(review): relies on a private attribute of the project's connection wrapper.
        last_cursor = conn._last_cursor
        deleted_count = 0 if not last_cursor else last_cursor.rowcount
        await conn.commit()

        if deleted_count > 0:
            logger.info(f"Cleaned up {deleted_count} old log entries (older than {LOG_RETENTION_DAYS} days)")
    except Exception as e:
        logger.error(f"Error cleaning up old logs: {e}")


async def _schedule_log_cleanup():
    """Run ``_cleanup_old_logs`` now and then daily until shutdown.

    The wait between runs is 24 hours, or 1 hour after an unexpected error.
    The wait is interruptible: it ends as soon as ``_shutdown_event`` is set,
    so the task exits promptly on shutdown instead of sleeping up to a day
    before noticing.
    """
    logger.info("Started log cleanup scheduler")

    while True:
        try:
            await _cleanup_old_logs()
            delay = 86400  # 24 hours
        except Exception as e:
            logger.error(f"Error in log cleanup scheduler: {e}")
            delay = 3600  # Retry in 1 hour on error

        try:
            # Sleep until the next run, but wake immediately on shutdown.
            await asyncio.wait_for(_shutdown_event.wait(), timeout=delay)
        except asyncio.TimeoutError:
            continue  # interval elapsed without shutdown: run cleanup again
        logger.info("Log cleanup scheduler shutting down")
        break


def _should_skip_logging(path: str) -> bool:
    """Return True when requests to *path* should not be logged.

    A path is skipped if it is listed in SKIP_PATHS, ends with one of the
    STATIC_EXTENSIONS, contains a ``/js/`` segment, or ends with
    ``/styles.css`` (the last two are served by custom handlers).
    """
    is_listed = path in SKIP_PATHS
    is_static_file = path.endswith(tuple(STATIC_EXTENSIONS))
    is_custom_asset = '/js/' in path or path.endswith('/styles.css')
    return is_listed or is_static_file or is_custom_asset


# Column widths from the api_logs CREATE TABLE statement. In strict SQL mode
# (MySQL's default since 5.7) over-long values are rejected with an error
# rather than truncated. Because rows are written with one executemany() per
# batch, a single over-long path or header would likely drop the whole batch.
_MAX_METHOD_LEN = 10
_MAX_CLIENT_IP_LEN = 45
_MAX_VARCHAR_LEN = 255


def _clip(value: str, limit: int) -> str:
    """Return *value* truncated to at most *limit* characters."""
    return value[:limit]


async def _capture_request_body(request: Request, method: str):
    """Return the body of a POST/PUT/PATCH request for logging.

    The body is parsed as JSON when possible. Otherwise it is decoded as UTF-8
    text, with undecodable bytes replaced. Returns None for other methods, for
    an empty body, or if reading/parsing fails for any reason.
    """
    if method not in ("POST", "PUT", "PATCH"):
        return None
    try:
        body = await request.body()
        if not body:
            return None
        # Try to parse as JSON, fallback to string
        try:
            return json.loads(body.decode('utf-8'))
        except (json.JSONDecodeError, UnicodeDecodeError):
            return body.decode('utf-8', errors='replace')
    except Exception:
        return None


def _response_size(response) -> int:
    """Best-effort size of the response payload in bytes (0 when unknown).

    Uses ``len(response.body)`` when the response has a materialised body.
    Starlette's ``call_next`` in function-based (BaseHTTPMiddleware)
    middleware returns a streaming response without ``.body``. For that case,
    fall back to the Content-Length header so the size is not always 0.
    """
    try:
        body = getattr(response, "body", None)
        if body is not None:
            return len(body)
        headers = getattr(response, "headers", None)
        content_length = headers.get("content-length") if headers is not None else None
        return max(int(content_length), 0) if content_length else 0
    except Exception:
        # A logging metric must never break the request.
        return 0


async def mysql_logging_middleware(request: Request, call_next):
    """
    Logs HTTP requests to MySQL table `api_logs` asynchronously.

    Features:
    - Asynchronous logging via background queue (no blocking on DB writes)
    - Batch insertion for better performance
    - Skips static files (JS, CSS, images, fonts)
    - Skips health check endpoint
    - Supports sampling for high-traffic scenarios
    - Automatic cleanup of old logs

    Entries are only enqueued after start_logging_system() has created the
    queue; before that, requests pass through unlogged. Values are truncated
    to the widths of their VARCHAR columns so one oversized request cannot make
    the database reject a whole batch.

    NOTE(review): request bodies are stored verbatim (including any passwords
    or tokens they contain) and without a size limit; consider redaction and
    capping.
    """
    path = request.url.path

    # Pass straight through when logging is disabled or the path is static/health.
    if not ENABLE_LOGGING or _should_skip_logging(path):
        return await call_next(request)

    # Apply sampling if configured
    if LOGGING_SAMPLE_RATE < 1.0:
        import random
        if random.random() > LOGGING_SAMPLE_RATE:
            return await call_next(request)

    start_time = time.perf_counter()

    client_ip = request.client.host if request.client else "unknown"
    method = request.method
    url = str(request.url)
    query_params = dict(request.query_params)

    # Extract browser and device information from headers
    headers = request.headers
    user_agent = headers.get("user-agent", "unknown")
    accept_language = headers.get("accept-language", "unknown")
    accept_encoding = headers.get("accept-encoding", "unknown")
    referer = headers.get("referer", "unknown")
    x_forwarded_for = headers.get("x-forwarded-for", "unknown")

    request_body = await _capture_request_body(request, method)

    response = await call_next(request)

    process_time = time.perf_counter() - start_time
    status_code = getattr(response, "status_code", None)
    response_size = _response_size(response)

    log_entry = LogEntry(
        timestamp=datetime.now(),
        client_ip=_clip(client_ip, _MAX_CLIENT_IP_LEN),
        method=_clip(method, _MAX_METHOD_LEN),
        url=url,
        path=_clip(path, _MAX_VARCHAR_LEN),
        query_params=json.dumps(query_params) if query_params else "{}",
        # `is not None` so falsy-but-valid JSON bodies ({}, [], 0, false, "")
        # are recorded instead of being stored as NULL.
        request_body=json.dumps(request_body) if request_body is not None else None,
        user_agent=user_agent,
        accept_language=_clip(accept_language, _MAX_VARCHAR_LEN),
        accept_encoding=_clip(accept_encoding, _MAX_VARCHAR_LEN),
        referer=referer,
        x_forwarded_for=_clip(x_forwarded_for, _MAX_VARCHAR_LEN),
        status_code=status_code,
        process_time=process_time,
        response_size=response_size,
    )

    # Enqueue log entry for async processing (non-blocking)
    try:
        if _log_queue is not None:
            _log_queue.put_nowait(log_entry)
    except Exception as e:
        logger.error(f"Error enqueueing log entry: {e}")

    return response


# Strong references to fire-and-forget background tasks. The event loop keeps
# only weak references to tasks, so an unreferenced task may be garbage
# collected before it finishes. Each task removes itself when done.
_background_tasks: set[asyncio.Task] = set()


async def start_logging_system():
    """Initialize the async logging system. Call this on application startup.

    Ensures the ``api_logs`` table exists. If it does, creates the log queue
    and shutdown event, then starts the batch-writer worker and the daily
    cleanup scheduler. If the database is unreachable, nothing is started and
    the middleware passes requests through unlogged. A second call after a
    successful start only logs a warning.
    """
    global _log_queue, _log_worker_task, _shutdown_event

    if _log_queue is not None:
        logger.warning("Logging system already started")
        return

    # Ensure table exists (will fail gracefully if no DB connection)
    await _ensure_api_logs_table_exists()

    # Only start queue if table is ready (i.e., database is accessible)
    if not _API_LOGS_TABLE_READY:
        logger.warning("Database not accessible, logging system will not start")
        return

    # Create queue and shutdown event
    _log_queue = asyncio.Queue()
    _shutdown_event = asyncio.Event()

    # Start background worker (kept in a module global, which holds a reference)
    _log_worker_task = asyncio.create_task(_background_log_worker())

    # Start cleanup scheduler; keep a reference so it cannot be garbage collected.
    cleanup_task = asyncio.create_task(_schedule_log_cleanup())
    _background_tasks.add(cleanup_task)
    cleanup_task.add_done_callback(_background_tasks.discard)

    logger.info("Async logging system started")


async def stop_logging_system():
    """Gracefully shutdown the async logging system. Call this on application shutdown.

    Does nothing if the system was never started. Otherwise it signals
    shutdown and gives the worker up to 10 seconds to flush the queue. If the
    worker overruns, it is cancelled.
    """
    if _shutdown_event is not None:
        logger.info("Shutting down async logging system...")

        # Signal shutdown
        _shutdown_event.set()

        worker = _log_worker_task
        if worker is not None:
            try:
                # Wait for worker to finish processing queue
                await asyncio.wait_for(worker, timeout=10.0)
            except asyncio.TimeoutError:
                logger.warning("Log worker did not finish within timeout, cancelling")
                worker.cancel()

        logger.info("Async logging system stopped")
