#!/usr/bin/env python3
"""
mkv_verify.py — Markovian Protocol ZK Proof Verifier
Version 1.0 — June 2026

Standalone verifier. No SigmaSynth infrastructure required.
Dependencies: pip install py_ecc numpy

Usage:
    python3 mkv_verify.py proof.json
    python3 mkv_verify.py --test

Verifies that a MarkovProof demonstrates a valid N-step Markov state
transition under the canonical governance matrix M (m_version=1).

The proof is zero-knowledge: it proves the computation was performed
correctly without revealing the input or output state vectors.

Specification: chain.quantsynth.net/zkspec
"""

import hashlib
import json
import sys

import numpy as np
from py_ecc.bn128 import G1, multiply, add, neg, curve_order
from py_ecc.fields import bn128_FQ as FQ

# ── Protocol constants ────────────────────────────────────────────────────────

SCALE_S   = 1_000_000_000  # probabilities carried as integers: 0.70 -> 700_000_000
M_DENOM   = 20             # common denominator: every M.T entry times 20 is an integer
TOLERANCE = 50             # largest accepted rounding correction |epsilon[j]|

# Second generator H, derived deterministically from a public seed string
# (no trusted setup). verify_step subtracts eps*H from each statement point,
# so H is presumably the value generator of the commitments.
#
# SECURITY(review): H = _H_SEED * G1 and _H_SEED is computable by anyone from
# the public seed, so log_G1(H) is public. verify_step only checks a Schnorr
# proof of knowledge of log_G1(D). If commitments have the form r*G1 + v*H (as
# the eps*H term suggests), a prover knows log_G1 of every commitment
# (r + v*_H_SEED) for ANY values v. That means the transition relation is not
# actually enforced. H must instead come from a hash-to-curve map (e.g.
# RFC 9380, or try-and-increment on y^2 = x^3 + 3) so that nobody knows its
# discrete log. Changing H is a protocol change and must be made together
# with the prover.
_H_SEED = int.from_bytes(
    hashlib.sha256(b"Markovian-H-generator-v1").digest(), "big"
) % curve_order
H = multiply(G1, _H_SEED)

# Canonical governance matrices keyed by m_version.
# Row i is the distribution of next states when leaving state i (each row
# sums to 1). verify_step therefore checks s_out = s_in @ M for row vectors.
CANONICAL_M = {
    1: np.asarray(
        [
            [0.70, 0.25, 0.05],  # from ACCUMULATION
            [0.10, 0.75, 0.15],  # from MARKUP
            [0.20, 0.15, 0.65],  # from DISTRIBUTION
        ],
        dtype=np.float64,
    ),
}

# Human-readable state names, indexed by state number (0, 1, 2).
STATES = "ACCUMULATION MARKUP DISTRIBUTION".split()


# ── EC helpers ────────────────────────────────────────────────────────────────

def _to_ec(p):
    if p is None: return None
    if isinstance(p, (list, tuple)) and len(p) == 2:
        return (FQ(int(p[0])), FQ(int(p[1])))
    return p

def _pt(p):
    """Convert an ``[x, y]`` pair into a py_ecc affine point (no validation)."""
    x = FQ(int(p[0]))
    y = FQ(int(p[1]))
    return (x, y)
def _sm(k, P):
    """Return k*P with k reduced mod curve_order; a zero scalar gives None (infinity)."""
    scalar = int(k) % curve_order
    if scalar:
        return multiply(P, scalar)
    return None
def _pa(P, Q):
    if P is None: return Q
    if Q is None: return P
    return add(P, Q)
def _ps(P, Q):
    """Return P - Q, computed as P + (-Q), with None meaning infinity."""
    negated = neg(Q)
    return _pa(P, negated)
def _hash(*parts):
    """Fiat-Shamir challenge derived from *parts*.

    Each part is rendered with str() and UTF-8 encoded. Bytes contexts
    therefore hash as their ``b'...'`` repr. The parts are joined with
    b"||", hashed with SHA-256, read as a big-endian integer and reduced
    modulo curve_order.

    NOTE(review): this encoding is not injective. A part whose str()
    contains "||" can collide with a different split. Changing it would
    break compatibility with existing proofs.
    """
    payload = b"||".join(str(part).encode() for part in parts)
    digest = hashlib.sha256(payload).digest()
    return int.from_bytes(digest, "big") % curve_order


# ── M matrix ─────────────────────────────────────────────────────────────────

def make_m_int(M, denom=None):
    """Return ``M.T * denom`` as a list of lists of exact Python ints.

    Row j of the result holds the integer weights that verify_step applies
    to the input commitments for output component j, so
    ``result[j][k] == denom * M[k][j]``.

    Args:
        M: square 2-D array-like of transition probabilities (any size).
        denom: common denominator. Defaults to the protocol constant M_DENOM.

    Returns:
        list[list[int]] with the same dimensions as M.

    Raises:
        ValueError: if M is not a square 2-D matrix, or if some entry of
            ``M * denom`` is not an integer up to float rounding. Rounding it
            silently would make the verifier check a different matrix from
            the canonical one.
    """
    if denom is None:
        denom = M_DENOM
    Mt = np.asarray(M, dtype=np.float64).T
    if Mt.ndim != 2 or Mt.shape[0] != Mt.shape[1]:
        raise ValueError(f"M must be a square 2-D matrix, got shape {Mt.shape}")

    m_int = []
    for row in Mt:
        scaled = [float(v) * denom for v in row]
        rounded = [round(x) for x in scaled]
        # Only float representation error (~1e-15) may be rounded away.
        if any(abs(x - r) > 1e-9 for x, r in zip(scaled, rounded)):
            raise ValueError(f"M * {denom} is not integral: {scaled}")
        m_int.append(rounded)
    return m_int


# ── Step verifier ─────────────────────────────────────────────────────────────

def verify_step(M_INT, step, context=b""):
    """Verify one transition step of a MarkovProof.

    For each output component j the verifier builds the statement point

        D_j = M_DENOM*C_out[j] - sum_k M_INT[j][k]*C_in[k] - eps_j*H

    and checks the Fiat-Shamir Schnorr proof (R, s, e) attached to it:

        s*G1 + e*D_j == R    and    e == _hash(D_j, R, ctx_j),

    where ctx_j = context + b"|step|j=<j>|eps=<eps_j>". Each eps_j must be
    an integer with |eps_j| <= TOLERANCE.

    NOTE(review): this only shows that the prover knows log_G1(D_j). It
    enforces the transition only if nobody knows log_G1(H). With H defined
    as _H_SEED*G1 from a public seed, that log is public (see the H
    definition).

    Args:
        M_INT: square integer matrix from make_m_int. Row j weights the input
            commitments for output component j.
        step: mapping with the following keys, each a list of len(M_INT)
            entries:
            "C_in", "C_out": points as [x, y] or None;
            "epsilons": integer corrections;
            "proofs": dicts with keys "R", "s", "e".
        context: bytes bound into every challenge (domain separation).

    Returns:
        (True, "ok") on success, otherwise (False, reason).

    Raises:
        Malformed input (missing keys, undecodable points) raises.
        verify_proof turns any exception into a rejection.
    """
    dim = len(M_INT)
    c_in_raw = step["C_in"]
    c_out_raw = step["C_out"]
    epsilons = step["epsilons"]
    proofs = step["proofs"]
    for field, values in (("C_in", c_in_raw), ("C_out", c_out_raw),
                          ("epsilons", epsilons), ("proofs", proofs)):
        if len(values) != dim:
            return False, f"{field} has {len(values)} entries, expected {dim}"

    # Decode the input commitments once; every output component reuses them.
    c_in = [_to_ec(c) for c in c_in_raw]

    for j in range(dim):
        eps = epsilons[j]
        # The group relation uses int(eps), but the challenge context binds
        # str(eps). Require them to agree so that a fractional or string eps
        # cannot make the two diverge.
        try:
            eps_int = int(eps)
        except (TypeError, ValueError, OverflowError):
            return False, f"epsilon[{j}]={eps!r} is not an integer"
        if eps_int != eps:
            return False, f"epsilon[{j}]={eps!r} is not an integer"
        if abs(eps_int) > TOLERANCE:
            return False, f"epsilon[{j}]={eps} exceeds tolerance"

        D = _sm(M_DENOM, _to_ec(c_out_raw[j]))
        for k in range(dim):
            D = _ps(D, _sm(M_INT[j][k], c_in[k]))
        if eps_int != 0:
            D = _ps(D, _sm(eps_int, H))
        # The identity has log 0, so the statement is vacuous; reject it explicitly.
        if D is None:
            return False, f"Degenerate statement (point at infinity) at j={j}"

        # Keep str(eps) exactly as received: the prover hashed the same text.
        ctx = context + f"|step|j={j}|eps={eps}".encode()
        proof = proofs[j]
        R_claimed = _to_ec(proof["R"])
        if R_claimed is None:
            return False, f"Missing Schnorr commitment R at j={j}"
        s = int(proof["s"])
        e = int(proof["e"])

        lhs = _pa(_sm(s, G1), _sm(e, D))
        if lhs != R_claimed:
            return False, f"Schnorr check failed for output component j={j}"

        e_check = _hash([int(D[0]), int(D[1])], [int(R_claimed[0]), int(R_claimed[1])], ctx)
        if e_check != e:
            return False, f"Fiat-Shamir challenge mismatch at j={j}"

    return True, "ok"


# ── Main verifier ─────────────────────────────────────────────────────────────

def verify_proof(proof_dict):
    """Verify a complete MarkovProof.

    Checks, in order:
    1. n_steps is at least 1.
    2. steps is a list.
    3. m_version names a known canonical matrix.
    4. The number of steps equals n_steps.
    5. Every step passes verify_step under a context that binds m_version,
       N and the step index.
    6. Consecutive steps chain: step i's C_out equals step i+1's C_in.

    Args:
        proof_dict: decoded proof (e.g. from json.load) with keys
            "m_version", "n_steps" and "steps".

    Returns:
        (True, "verified") or (False, reason). Malformed input is reported as
        a rejection, never raised, because this is the boundary for
        untrusted data.
    """
    try:
        m_version = int(proof_dict["m_version"])
        N         = int(proof_dict["n_steps"])
        steps     = proof_dict["steps"]

        # A zero-step proof contains no checks at all and would otherwise
        # "verify" vacuously.
        if N < 1:
            return False, f"n_steps must be >= 1, got {N}"
        if not isinstance(steps, (list, tuple)):
            return False, "steps must be a list"
        if m_version not in CANONICAL_M:
            return False, f"Unknown m_version={m_version}"
        if len(steps) != N:
            return False, f"Expected {N} steps, got {len(steps)}"

        M_INT = make_m_int(CANONICAL_M[m_version])
        ctx_base = f"mkv|v{m_version}|N={N}".encode()

        for i, step in enumerate(steps):
            ok, msg = verify_step(M_INT, step, ctx_base + f"|i={i}".encode())
            if not ok:
                return False, f"Step {i}: {msg}"
            if i + 1 < N and step["C_out"] != steps[i + 1]["C_in"]:
                return False, f"Chain break between step {i} and {i+1}"

        return True, "verified"
    except Exception as exc:  # verification boundary: reject anything malformed
        return False, f"Exception: {exc}"


# ── Test vectors ──────────────────────────────────────────────────────────────

# Round-trip vectors for run_tests(), which proves N=2 steps under
# CANONICAL_M[1]. expected_regime is the argmax state of s_in @ M^2, the same
# relation verify_step enforces: s_out[j] = sum_k s_in[k] * M[k][j].
TEST_VECTORS = [
    # s_in @ M^2 ~= [0.303, 0.268, 0.429]. With a 0.65 self-transition, two
    # steps do not leave DISTRIBUTION (MARKUP only overtakes it at N=3).
    {"label": "SPY COVID crash",  "s_in": [0.073496, 1e-6, 0.926502], "expected_regime": "DISTRIBUTION"},
    # s_in @ M^2 ~= [0.19, 0.56, 0.25]
    {"label": "QQQ Bull trend",   "s_in": [1e-6, 0.85, 0.149999],     "expected_regime": "MARKUP"},
    # s_in @ M^2 ~= [0.45, 0.34, 0.21]
    {"label": "Accumulation",     "s_in": [0.70, 1e-6, 0.299999],     "expected_regime": "ACCUMULATION"},
    # s_in @ M^2 ~= [0.33, 0.41, 0.26]
    {"label": "Neutral",          "s_in": [0.333, 0.334, 0.333],      "expected_regime": "MARKUP"},
]

def _load_zk_markov(path):
    """Load the reference prover module ``zk_markov`` from the file *path*.

    Follows the importlib "import a source file directly" recipe. The module
    is registered in sys.modules before it executes, and removed again if
    execution fails.

    Raises:
        ImportError: if no module spec/loader can be built for *path*.
        Any exception raised while executing the file, e.g.
        FileNotFoundError when *path* does not exist.
    """
    import importlib.util

    spec = importlib.util.spec_from_file_location("zk_markov", path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create a module spec for {path!r}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        sys.modules.pop(spec.name, None)
        raise
    return module


def run_tests(zk_path=None):
    """Round-trip self-test of the verifier.

    Each TEST_VECTORS entry is proved with the reference prover
    (zk_markov.prove_markov, N=2, m_version=1). The proof is then checked
    with verify_proof, and the argmax regime of the prover's s_out is
    compared with the vector's expected_regime.

    Args:
        zk_path: path to zk_markov.py. Defaults to ~/markovian/zk_markov.py.

    Returns:
        True if every vector passed, False if any failed, or None if the
        prover could not be loaded (the suite is skipped after printing why).
    """
    import os
    import time

    if zk_path is None:
        zk_path = os.path.expanduser("~/markovian/zk_markov.py")
    try:
        zk = _load_zk_markov(zk_path)
    except Exception as exc:  # importing arbitrary prover code can fail in any way
        print(f"Test generation requires zk_markov.py: {exc}")
        return None

    print("Markovian Protocol ZK Verifier — Test Suite")
    print("=" * 60)
    M = CANONICAL_M[1]
    all_pass = True
    for tv in TEST_VECTORS:
        t0 = time.perf_counter()
        s_out, proof = zk.prove_markov(M, tv["s_in"], N=2, m_version=1)
        pd = zk.proof_to_dict(proof)
        t_prove = (time.perf_counter() - t0) * 1000

        t1 = time.perf_counter()
        ok, msg = verify_proof(pd)
        t_verify = (time.perf_counter() - t1) * 1000

        regime = STATES[int(np.argmax(s_out))]
        conf = max(s_out) * 100
        passed = bool(ok) and regime == tv["expected_regime"]
        all_pass = all_pass and passed
        status = "PASS" if passed else "FAIL"

        print(f"  [{status}] {tv['label']:<20} → {regime:<14} {conf:.1f}%  "
              f"prove={t_prove:.0f}ms  verify={t_verify:.0f}ms  {msg}")

    print("=" * 60)
    print(f"  Result: {'ALL PASS' if all_pass else 'FAILURES DETECTED'}")
    return all_pass


# ── CLI ───────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    if "--test" in sys.argv:
        # run_tests() returns False when a vector fails. None (suite skipped)
        # and True both exit 0.
        sys.exit(1 if run_tests() is False else 0)

    if len(sys.argv) < 2:
        print(__doc__)
        sys.exit(1)

    try:
        with open(sys.argv[1], encoding="utf-8") as f:
            proof_dict = json.load(f)
    except (OSError, ValueError) as exc:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        # Report in the same JSON shape as a rejected proof instead of a traceback.
        print(json.dumps({"verified": False, "message": f"Cannot read proof: {exc}"}, indent=2))
        sys.exit(1)

    ok, msg = verify_proof(proof_dict)
    # A valid JSON document need not be an object; only read metadata from dicts.
    meta = proof_dict if isinstance(proof_dict, dict) else {}
    result = {
        "verified":  ok,
        "message":   msg,
        "m_version": meta.get("m_version"),
        "n_steps":   meta.get("n_steps"),
        "type":      meta.get("type"),
    }
    print(json.dumps(result, indent=2))
    sys.exit(0 if ok else 1)
