#!/usr/bin/env python3
"""
SQL Checker — Analyse statique de fichiers SQL (sans connexion a une BDD)
Detecte : erreurs de syntaxe, patterns dangereux, doublons de cle primaire.
"""

import sys
import re

# Force UTF-8 on stdout so the non-ASCII characters printed by this script
# (e.g. the "—" in the report lines) cannot raise UnicodeEncodeError on a
# console using a legacy code page.  reconfigure() only exists on
# io.TextIOWrapper: when stdout has been replaced by another object, or is
# None (pythonw / GUI launch), skip it instead of crashing at import time.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8")

def p(*args, **kwargs):
    """Print exactly like the built-in print(), but always flush.

    Whatever ``flush`` value the caller passes is overridden with True, so
    every progress line is pushed out immediately.
    """
    options = dict(kwargs, flush=True)
    print(*args, **options)


# ─── Decoupe manuelle + suivi des numeros de ligne ───────────────────────────

def split_sql(content: str) -> list:
    """Split SQL *content* into statements on top-level semicolons.

    Returns a list of ``(stmt, line_start)`` tuples:
      - stmt       : statement text, stripped, without the trailing ';'
      - line_start : 1-based file line of the statement's first
                     non-whitespace character

    ``-- ...`` and ``/* ... */`` comments are removed, but the newlines they
    contained are kept, so counting newlines inside ``stmt`` gives offsets
    that match the file.  A block comment is replaced by a space so the
    tokens on either side are not glued together.  Inside '...' and "..."
    literals, a doubled quote and a backslash escape (MySQL style, e.g.
    ``'It\\'s'``) never end the literal, so a ';' inside it does not split.
    """
    results = []
    current = []
    in_single = in_double = in_line = in_block = False
    i = 0
    n = len(content)
    line_no = 1          # file line of content[i]
    stmt_start = None    # file line of the current statement's first real char

    while i < n:
        c = content[i]
        c2 = content[i:i + 2]

        if in_line:
            if c == "\n":
                # End of a "--" comment: keep the newline so line offsets hold.
                in_line = False
                current.append("\n")
                line_no += 1
            i += 1
            continue

        if in_block:
            if c2 == "*/":
                in_block = False
                current.append(" ")     # the comment still separates tokens
                i += 2
            else:
                if c == "\n":
                    current.append("\n")
                    line_no += 1
                i += 1
            continue

        in_quote = in_single or in_double
        if not in_quote:
            if c2 == "--":
                in_line = True
                i += 2
                continue
            if c2 == "/*":
                in_block = True
                i += 2
                continue
        elif c == "\\" and i + 1 < n:
            # Backslash escape inside a literal: copy both characters verbatim.
            nxt = content[i + 1]
            current.append(c)
            current.append(nxt)
            if nxt == "\n":
                line_no += 1
            i += 2
            continue

        if c == "'" and not in_double:
            if in_single and content[i + 1:i + 2] == "'":
                # '' is an escaped quote: keep both and stay in the literal.
                current.append("''")
                i += 2
                continue
            in_single = not in_single
        elif c == '"' and not in_single:
            in_double = not in_double
        elif c == ";" and not in_quote:
            stmt = "".join(current).strip()
            if stmt:
                start = stmt_start if stmt_start is not None else line_no
                results.append((stmt, start))
            current = []
            stmt_start = None           # set again by the next real character
            i += 1
            continue

        if c == "\n":
            line_no += 1
        elif stmt_start is None and not c.isspace():
            stmt_start = line_no
        current.append(c)
        i += 1

    last = "".join(current).strip()
    if last:
        start = stmt_start if stmt_start is not None else line_no
        results.append((last, start))
    return results


# ─── Detection doublons de cle primaire ──────────────────────────────────────

# Primary-key values already seen, keyed by (table, lower-cased pk column).
# Module-level on purpose: duplicates are detected ACROSS successive INSERT
# statements of the analysed file, not only inside a single statement.
# NOTE(review): never reset -- a DELETE/TRUNCATE between two INSERTs is not
# taken into account.
_pk_registry: dict = {}

# Head of an INSERT: optional MySQL modifiers, optional INTO, optional
# "schema." qualifier, then the table name.
_INSERT_HEAD_RE = re.compile(
    r"INSERT\s+(?P<mods>(?:(?:LOW_PRIORITY|DELAYED|HIGH_PRIORITY|IGNORE)\s+)*)"
    r"(?:INTO\s+)?(?:[`\"]?\w+[`\"]?\s*\.\s*)?[`\"]?(?P<table>\w+)[`\"]?",
    re.IGNORECASE,
)
# The VALUES keyword, but not a `values` identifier wrapped in quotes.
_VALUES_KW_RE = re.compile(r"(?<![\w`\"])VALUES(?![\w`\"])", re.IGNORECASE)
_ODKU_RE = re.compile(r"\bON\s+DUPLICATE\s+KEY\s+UPDATE\b", re.IGNORECASE)


def _iter_row_heads(block: str):
    """Yield ``(offset, first_field)`` for every top-level ``(...)`` tuple.

    *offset* is the index of the tuple's opening parenthesis in *block*;
    *first_field* is the stripped text of its first comma-separated field.
    Quoted spans ('...', "...", `...`) are consumed as units -- doubled
    delimiters and, outside backticks, backslash escapes included -- so a
    comma or parenthesis inside a literal never splits a field.
    """
    depth = 0
    quote = None
    row_start = 0
    field = None          # chars of the first field while it is being read
    i, n = 0, len(block)
    while i < n:
        c = block[i]
        step = 1
        if quote is not None:
            if (c == "\\" and quote != "`") or (c == quote and block[i + 1:i + 2] == quote):
                step = 2                      # escaped char / doubled delimiter
            elif c == quote:
                quote = None
        elif c in "'\"`":
            quote = c
        elif c == "(":
            depth += 1
            if depth == 1:                    # start of a new tuple
                row_start = i
                field = []
                i += 1
                continue
        elif c == ")" or (c == "," and depth == 1):
            if depth == 1 and field is not None:
                yield row_start, "".join(field).strip()
                field = None
            if c == ")":
                depth = max(depth - 1, 0)
        if field is not None:
            field.append(block[i:i + step])
        i += step


def check_duplicate_pk(raw: str) -> list:
    """Detect primary-key duplicates (MySQL error #1062) in an INSERT.

    The primary key is assumed to be the FIRST column of the explicit column
    list; INSERTs without a column list are not checked.  Values are compared
    as text with surrounding quotes removed, against every value already
    registered in ``_pk_registry`` for the same table/column -- including
    values from earlier statements.  Unquoted identifiers/keywords (NULL,
    DEFAULT, ...) and function calls (UUID(), ...) are ignored.

    INSERT IGNORE and INSERT ... ON DUPLICATE KEY UPDATE do not fail on a
    duplicate key, so their values are registered but never reported.

    Returns a list of dicts with keys ``pk_val``, ``pk_col``, ``table`` and
    ``rel_line`` (1-based line of the duplicated tuple inside *raw*).
    """
    m_table = _INSERT_HEAD_RE.match(raw)
    if not m_table:
        return []
    table = m_table.group("table").lower()
    tolerant = "IGNORE" in m_table.group("mods").upper()

    m_values = _VALUES_KW_RE.search(raw, m_table.end())
    if not m_values:
        return []

    # The column list must sit right before VALUES.
    head = raw[m_table.end():m_values.start()]
    m_cols = re.search(r"\(([^)]+)\)\s*$", head)
    if not m_cols:
        return []
    cols = [c.strip().strip("`\"' ") for c in m_cols.group(1).split(",")]
    pk_col = cols[0]
    if not pk_col:
        return []

    # Rows end where an ON DUPLICATE KEY UPDATE clause begins; its own
    # parentheses (e.g. VALUES(col)) must not be parsed as rows.
    m_odku = _ODKU_RE.search(raw, m_values.end())
    if m_odku:
        tolerant = True
    values_end = m_odku.start() if m_odku else len(raw)
    values_block = raw[m_values.end():values_end]

    seen = _pk_registry.setdefault((table, pk_col.lower()), set())

    errors = []
    line = raw.count("\n", 0, m_values.end()) + 1
    pos = 0
    for offset, pk_val_raw in _iter_row_heads(values_block):
        # Incremental newline count keeps the scan linear in len(raw).
        line += values_block.count("\n", pos, offset)
        pos = offset

        pk_val = pk_val_raw.strip("`\"'")
        if not pk_val:
            continue
        # Unquoted identifier or keyword such as NULL/DEFAULT: not a literal key.
        if re.fullmatch(r"[`\"]?\w+[`\"]?", pk_val_raw) and not re.fullmatch(r"\d+", pk_val):
            continue
        # Function call (UUID(), NOW(), ...): value unknown statically.
        if "(" in pk_val_raw and pk_val_raw[0] not in "'\"":
            continue

        if pk_val not in seen:
            seen.add(pk_val)
        elif not tolerant:
            errors.append({
                "pk_val":   pk_val,
                "pk_col":   pk_col,
                "table":    table,
                "rel_line": line,   # line relative to the statement start
            })

    return errors


# ─── Regles d'analyse ────────────────────────────────────────────────────────

# (pattern, message) pairs; analyse() searches each pattern (re.DOTALL) in the
# upper-cased statement in which the content of quoted literals is blanked.
DANGEROUS = [
    (r"\bDROP\s+DATABASE\b",                    "[DANGER] DROP DATABASE detecte"),
    (r"\bDROP\s+TABLE\b",                       "[ATTENTION] DROP TABLE detecte"),
    (r"\bTRUNCATE\b",                           "[ATTENTION] TRUNCATE detecte"),
    (r"\bDELETE\s+FROM\b(?![\s\S]*\bWHERE\b)", "[ATTENTION] DELETE sans WHERE"),
    (r"^UPDATE\b(?![\s\S]*\bWHERE\b)",          "[ATTENTION] UPDATE sans WHERE"),
]

# Keywords a statement may start with; anything else is reported as an error.
VALID_STARTERS = {
    "SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER", "TRUNCATE",
    "REPLACE", "SET", "USE", "SHOW", "DESCRIBE", "EXPLAIN", "BEGIN", "COMMIT",
    "ROLLBACK", "GRANT", "REVOKE", "CALL", "LOCK", "UNLOCK", "START",
    # Further MySQL/MariaDB statements that were missing from the list above.
    "WITH", "DESC", "RENAME", "LOAD", "DO", "HANDLER", "TABLE", "VALUES",
    "ANALYZE", "OPTIMIZE", "REPAIR", "CHECK", "CHECKSUM", "FLUSH",
    "SAVEPOINT", "RELEASE", "PREPARE", "EXECUTE", "DEALLOCATE", "XA",
}


def _scan_sql(raw: str):
    """One lexical pass over a statement, honouring quoted spans.

    Quoted spans are '...', "..." and `...`; inside them a doubled delimiter
    is an escaped delimiter and, except in backticks, a backslash escapes the
    next character.

    Returns ``(masked, depth, min_depth, open_quote)``:
      masked     -- *raw* with every character inside a quoted span replaced
                    by a space (delimiters and newlines kept), so offsets and
                    line numbers are unchanged;
      depth      -- final balance of "(" minus ")" outside quoted spans;
      min_depth  -- lowest balance reached (< 0: a ")" came before its "(");
      open_quote -- delimiter still open at the end of *raw*, or None.
    """
    chars = list(raw)
    depth = min_depth = 0
    quote = None
    i, n = 0, len(raw)

    def blank(j):
        if chars[j] != "\n":
            chars[j] = " "

    while i < n:
        c = raw[i]
        if quote is None:
            if c in "'\"`":
                quote = c
            elif c == "(":
                depth += 1
            elif c == ")":
                depth -= 1
                if depth < min_depth:
                    min_depth = depth
            i += 1
        elif (c == "\\" and quote != "`" and i + 1 < n) or (c == quote and raw[i + 1:i + 2] == quote):
            # Escaped character or doubled delimiter: both stay inside the span.
            blank(i)
            blank(i + 1)
            i += 2
        elif c == quote:
            quote = None
            i += 1
        else:
            blank(i)
            i += 1
    return "".join(chars), depth, min_depth, quote


def analyse(raw: str, stmt_line: int) -> list:
    """Run the static checks on one statement.

    raw       -- statement text (as produced by split_sql)
    stmt_line -- 1-based file line where the statement starts

    Returns a list of ``{"level", "line", "msg"}`` dicts; level is "ERROR"
    or "WARN".
    """
    issues = []

    # Leading keyword: "SELECT*", "SELECT(1)" and "(SELECT ...)" all give SELECT.
    m_first = re.match(r"[\s(]*([A-Za-z_]\w*)", raw)
    if m_first:
        first = m_first.group(1).upper()
    else:
        first = raw.strip().split()[0].upper() if raw.strip() else ""
    if first not in VALID_STARTERS:
        issues.append({"level": "ERROR", "line": stmt_line,
                       "msg": "Mot-cle SQL inconnu : '{}'".format(first)})

    masked, depth, min_depth, open_quote = _scan_sql(raw)

    if depth > 0:
        paren_msg = "{} parenthese(s) ouvrante(s) non fermee(s)".format(depth)
    elif depth < 0:
        paren_msg = "{} parenthese(s) fermante(s) en trop".format(-depth)
    elif min_depth < 0:
        paren_msg = "parenthese fermante avant son ouvrante"
    else:
        paren_msg = None
    if paren_msg:
        issues.append({"level": "ERROR", "line": stmt_line,
                       "msg": "Parentheses non equilibrees : " + paren_msg})

    if open_quote is not None:
        quote_msg = {
            "'": "Guillemet simple non ferme",
            '"': "Guillemet double non ferme",
            "`": "Backtick non ferme",
        }[open_quote]
        issues.append({"level": "ERROR", "line": stmt_line, "msg": quote_msg})

    # Literal contents are blanked: a 'DROP TABLE' stored as text must not
    # warn, and a 'WHERE' inside a literal must not hide a missing WHERE.
    upper = masked.upper()
    for pattern, msg in DANGEROUS:
        if re.search(pattern, upper, re.DOTALL):
            issues.append({"level": "WARN", "line": stmt_line, "msg": msg})

    if first == "INSERT":
        # A second INSERT keyword outside literals (whole word only, so a
        # column such as inserted_at does not count) means two statements
        # were merged, typically because a ';' is missing.
        inserts = re.finditer(r"\bINSERT\b", masked, re.IGNORECASE)
        next(inserts, None)                 # the statement's own INSERT
        nested = next(inserts, None)
        if nested is not None:
            nested_line = stmt_line + masked.count("\n", 0, nested.start())
            issues.append({
                "level": "ERROR",
                "line":  nested_line,
                "msg":   "#1064 INSERT imbrique ligne {} : un INSERT apparait dans le corps d'un autre INSERT -- point-virgule manquant apres la ligne precedente ?".format(nested_line)
            })

        for dup in check_duplicate_pk(raw):
            abs_line = stmt_line + dup["rel_line"] - 1
            issues.append({
                "level": "ERROR",
                "line":  abs_line,
                "msg":   "#1062 Doublon PK : colonne '{}', valeur '{}' dans '{}' (ligne {})".format(
                    dup["pk_col"], dup["pk_val"], dup["table"], abs_line
                )
            })

    return issues


# ─── Programme principal ─────────────────────────────────────────────────────

def _read_sql_file(path: str) -> str:
    """Return the text of *path*; print a message and exit(1) if unreadable.

    The file is decoded as UTF-8 with "utf-8-sig", which also drops a leading
    byte-order mark so it does not end up in front of the first SQL keyword.
    On a decode error it is re-read as latin-1, which accepts any byte.
    """
    try:
        with open(path, "r", encoding="utf-8-sig") as f:
            return f.read()
    except FileNotFoundError:
        p("Fichier introuvable : {}".format(path))
        sys.exit(1)
    except UnicodeDecodeError:
        pass                                   # retried as latin-1 below
    except OSError as e:                       # permission denied, directory...
        p("Impossible de lire le fichier : {}".format(e))
        sys.exit(1)

    try:
        with open(path, "r", encoding="latin-1") as f:
            content = f.read()
    except OSError as e:
        p("Impossible de lire le fichier : {}".format(e))
        sys.exit(1)
    p("[INFO] Fichier lu en encodage latin-1")
    return content


def _print_issue_list(title: str, items: list):
    """Print one section of the final summary; nothing when *items* is empty.

    Each item is a ``(statement_no, line, snippet, message)`` tuple.
    """
    if not items:
        return
    p(title)
    for no, line, snip, msg in items:
        p("  - Requete #{} | Ligne {} : {}".format(no, line, msg))
        p("    {}".format(snip))
    p()


def check_sql_file(path: str):
    """Analyse every statement of *path*, print a report, then exit.

    Exit status: 1 if at least one error was found or the file could not be
    read, 0 otherwise (including a file with no statement at all).
    """
    content = _read_sql_file(path)

    p("\nDecoupage du fichier en cours...")
    statements = split_sql(content)
    total = len(statements)

    if total == 0:
        p("Aucune requete SQL detectee."); sys.exit(0)

    p("\n" + "="*65)
    p("  SQL Checker  --  Analyse statique")
    p("  Fichier : {}".format(path))
    p("  {} requete(s) detectee(s)".format(total))
    p("="*65 + "\n")

    errors = []
    warnings = []

    for i, (raw, stmt_line) in enumerate(statements, start=1):
        pct     = round(i / total * 100, 1)
        preview = raw.replace("\n", " ")[:70]
        ellip   = "..." if len(raw) > 70 else ""

        p("[{:>5}/{}] {:>5}%  {}{}".format(i, total, pct, preview, ellip))

        issues = analyse(raw, stmt_line)
        if not issues:
            p("               OK\n")
            continue
        for issue in issues:
            is_error = issue["level"] == "ERROR"
            tag = "[ERREUR]" if is_error else "[WARN]"
            p("               {} ligne {} — {}".format(tag, issue["line"], issue["msg"]))
            entry = (i, issue["line"], preview + ellip, issue["msg"])
            (errors if is_error else warnings).append(entry)
        p()

    p("\n" + "="*65)
    p("  RESUME FINAL")
    p("="*65)
    p("  Requetes analysees : {}".format(total))
    p("  Erreurs            : {}".format(len(errors)))
    p("  Avertissements     : {}".format(len(warnings)))
    p()

    _print_issue_list("ERREURS :", errors)
    _print_issue_list("AVERTISSEMENTS :", warnings)

    if not errors and not warnings:
        p("Aucune anomalie detectee -- fichier SQL valide.\n")

    p("="*65 + "\n")
    sys.exit(1 if errors else 0)


# ─── Mode diagnostic : localise les parentheses non equilibrees ──────────────

def diagnose_parentheses(path: str, stmt_index: int):
    """Show, line by line, how the parenthesis depth evolves in one statement.

    Re-reads *path*, isolates statement number *stmt_index* (1-based, same
    numbering as the main report since both use split_sql) and prints every
    line that opens or closes parentheses together with the running depth.
    The summary lists closing parentheses that had no opener and the lines
    on which still-unclosed parentheses were opened.

    Quoted spans ('...', "...", `...`, possibly spanning several lines) are
    skipped; a doubled delimiter and, outside backticks, a backslash escape
    do not end a span.
    """
    try:
        try:
            with open(path, "r", encoding="utf-8-sig") as f:
                content = f.read()
        except UnicodeDecodeError:
            with open(path, "r", encoding="latin-1") as f:
                content = f.read()
    except FileNotFoundError:
        p("Fichier introuvable : {}".format(path)); sys.exit(1)
    except OSError as e:
        p("Impossible de lire le fichier : {}".format(e)); sys.exit(1)

    statements = split_sql(content)
    if stmt_index < 1 or stmt_index > len(statements):
        p("Numero de requete invalide (1-{})".format(len(statements))); sys.exit(1)

    raw, stmt_line = statements[stmt_index - 1]

    p("\n" + "="*65)
    p("  DIAGNOSTIC — Requete #{} (debut ligne {})".format(stmt_index, stmt_line))
    p("="*65)

    quote = None       # delimiter of the quoted span in progress (may cross lines)
    open_lines = []    # stack: file line of every "(" not closed yet
    problems = []

    for rel, line in enumerate(raw.split("\n"), start=1):
        abs_line = stmt_line + rel - 1
        opened = closed = 0
        went_negative = False
        i = 0
        while i < len(line):
            c = line[i]
            if quote is not None:
                if (c == "\\" and quote != "`") or (c == quote and line[i + 1:i + 2] == quote):
                    i += 2          # escaped character / doubled delimiter
                    continue
                if c == quote:
                    quote = None
            elif c in "'\"`":
                quote = c
            elif c == "(":
                opened += 1
                open_lines.append(abs_line)
            elif c == ")":
                closed += 1
                if open_lines:
                    open_lines.pop()
                else:
                    # Depth would go negative: record it, keep depth at 0.
                    went_negative = True
                    problems.append((abs_line, rel, line.strip(),
                                     "parenthese fermante sans ouvrante (depth < 0)"))
            i += 1

        # Only lines that change the depth are shown.
        if opened or closed:
            marker = " <-- profondeur negative !" if went_negative else ""
            p("  Ligne {:>5} (rel {:>4}) | +{} -{}  depth={}{}".format(
                abs_line, rel, opened, closed, len(open_lines), marker))

    p()
    if open_lines:
        p("  Bilan : {} parenthese(s) ouvrante(s) non fermee(s)".format(len(open_lines)))
        p("  Ouverte(s) aux lignes : {}".format(", ".join(str(ln) for ln in open_lines)))
    if problems:
        p("\n  Anomalies detectees :")
        for al, _rel, txt, reason in problems:
            p("    Ligne {} : {}".format(al, reason))
            p("    >>> {}".format(txt[:100]))
    p("="*65 + "\n")


# Entry point -- kept at the very end of the module so that every function it
# may call (diagnose_parentheses included) is already defined when it runs.
if __name__ == "__main__":
    if len(sys.argv) == 3:
        # Mode diagnostic : python sql-checker.py <fichier.sql> <numero_requete>
        try:
            idx = int(sys.argv[2])
        except ValueError:
            p("Le 2e argument doit etre un numero de requete entier.")
            sys.exit(1)
        diagnose_parentheses(sys.argv[1], idx)
    elif len(sys.argv) == 2:
        check_sql_file(sys.argv[1])
    else:
        p("Usage : python sql-checker.py <fichier.sql> [numero_requete]")
        sys.exit(1)