from pwn import *
from wordle.solver import WordleSolver, WordleContext
import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "symbolic_mersenne_cracker"))
from main import Untwister

# Local copy of the answer word list. Answer indices are compared against RNG
# output, so this is presumably required to be in the server's order — confirm.
WORDLIST    = os.path.join(os.path.dirname(__file__), "wordlist.txt")
# Rounds played in Phase 1. Each getrandbits(32) call feeds two rounds (high
# then low 16 bits), so 2496 rounds = 1248 calls = 2 × 624-word MT blocks.
# Intended to pin down a unique state; Phase 2 prints how many it finds.
OBS_ROUNDS  = 2496
# Consecutive predicted first-guess wins played in Phase 3.
WIN_STREAK  = 20

# ── Server pattern → solver emoji ────────────────────────────────────────────
# Server feedback letters (G / Y / B) → the emoji squares that
# to_solver_feedback() produces for WordleSolver.play().
PATTERN_MAP = {"G": "🟩", "Y": "🟨", "B": "⬛"}

def to_solver_feedback(pat5):
    """Translate a server pattern such as 'GYBBG' into the solver's emoji string.

    Each character is looked up in PATTERN_MAP; any character other than
    G, Y or B raises KeyError.
    """
    return "".join(map(PATTERN_MAP.__getitem__, pat5))

def parse_pattern(line):
    """Return the text after the first ': ' of a whitespace-stripped line.

    For a line like 'Pattern: GYBBG' this yields 'GYBBG'. Raises IndexError
    when the line contains no ': ' separator.
    """
    stripped = line.strip()
    pieces = stripped.split(": ", maxsplit=1)
    return pieces[1]

# ── Constraint builder ────────────────────────────────────────────────────────
#
# Each getrandbits(32) output covers two rounds:
#   odd round  → high 16 bits (bits 31:16 of the 32-bit value)
#   even round → low  16 bits (bits 15:0)
#
# For one 16-bit idx_raw we observe:
#   bits 15:14 → unknown (2 bits, never visible)
#   bits 13:12 → LEAK    (2 bits, sent by server as "LEAK: XX"; half_bits does
#                         not currently add these to the constraint)
#   bit  11    → known unless word_idx in [len(words) - 2048, 2047]
#                (i.e. [267, 2047] when the list has 2315 words)
#   bits 10:0  → known from word_idx (lower 11 bits identical in both flip cases)

def half_bits(word_idx, leak, *, n_words=None, use_leak=False):
    """Return the 16-char '0'/'1'/'?' constraint for one 16-bit idx_raw.

    Layout, most significant bit first (mirrors predict_leak's decoding):
      bits 15:14  never visible                        -> '??'
      bits 13:12  the 2-bit LEAK                        -> '??' unless use_leak
      bit  11     may have been flipped by `idx ^= 0x800` (applied when the
                  12-bit idx is >= n_words), so it is ambiguous exactly when
                  word_idx < 0x800 and word_idx + 0x800 >= n_words
      bits 10:0   the low 11 bits of word_idx (identical in both cases)

    Args:
        word_idx: index of the observed answer in the word list.
        leak: 2-bit value parsed from the server's "LEAK:" line; only read
            when use_leak is true.
        n_words: length of the server's word list. Defaults to len(words);
            the former hard-coded range [267, 2048) is the n_words == 2315 case.
        use_leak: also pin bits 13:12 to `leak`. Off by default, which keeps
            the previous behaviour of leaving those bits unknown.

    Returns:
        A 16-character string of '0', '1' and '?'.
    """
    if n_words is None:
        n_words = len(words)

    # Bits 15:14 are never observable; bits 13:12 only if we trust the leak.
    prefix = "??" + (f"{(leak >> 1) & 1}{leak & 1}" if use_leak else "??")

    # The raw value could be word_idx itself, or word_idx ^ 0x800 if that raw
    # value is out of range and got folded back. The latter needs bit 11 clear
    # in word_idx; for word_idx >= 0x800 the fold source would be < n_words.
    if word_idx < 0x800 and (word_idx | 0x800) >= n_words:
        bit11 = "?"
    else:
        bit11 = str((word_idx >> 11) & 1)

    low11 = format(word_idx & 0x7FF, "011b")
    return prefix + bit11 + low11

def build_constraint(obs1, obs2):
    """Build the 32-char bit pattern for one getrandbits(32) output.

    obs1 describes the high 16 bits and obs2 the low 16 bits. Each is a
    (word_idx, leak) pair; a falsy observation (e.g. None for a lost round)
    contributes 16 unknown '?' bits.
    """
    unknown = "?" * 16
    halves = [half_bits(*ob) if ob else unknown for ob in (obs1, obs2)]
    return "".join(halves)

# ── Receive helpers ───────────────────────────────────────────────────────────

def recv_pattern(io):
    """Block until the server sends a line containing 'Pattern: '.

    Earlier lines are discarded. Returns a (raw_bytes, pat_str) tuple:
    raw_bytes is the matching line exactly as received, and pat_str is the
    decoded, whitespace-stripped text after its first ': '.
    """
    line = io.recvline()
    while b"Pattern: " not in line:
        line = io.recvline()
    text = line.decode().strip()
    return line, text.split(": ", 1)[1]

def goto_layer(io, layer_num):
    """Consume server output up to and including the 'LAYER <layer_num>' line.

    The layer number is zero-padded to four digits. Every 'LEAK: <bits>' seen
    on the way (including on the LAYER line itself) is parsed as a base-2
    integer; the last successfully parsed value is returned, or None if there
    was none. Unparseable LEAK payloads are skipped.
    """
    marker = "LAYER %04d" % layer_num
    found = None
    line = ""
    while marker not in line:
        line = io.recvline().decode(errors="replace")
        if "LEAK: " not in line:
            continue
        digits = line.split("LEAK: ")[1].strip()
        try:
            found = int(digits, 2)
        except ValueError:
            continue  # malformed leak: keep the previous value
    return found

# ── Solver setup ──────────────────────────────────────────────────────────────

# Word list, upper-cased, in file order: list position == word index.
with open(WORDLIST) as f:
    words = [entry.strip().upper() for entry in f]
valid_words     = set(words)
# On duplicate words the last position wins (same as a dict comprehension).
word_to_idx     = dict(zip(words, range(len(words))))

print("[*] Loading solver context...")
context = WordleContext(using_original_answer_set=True)
ws      = WordleSolver(context)
print("[*] Solver ready")

def pick_guess():
    """Choose the next guess, preferring words from the local list.

    The solver's best guess is used when it is in valid_words. Otherwise the
    first remaining possible answer (in solver order) that is in valid_words
    is used. If neither exists, the solver's best guess is returned anyway.
    """
    suggestion = ws.best_guess()
    if suggestion not in valid_words:
        remaining = (context.words[i] for i in ws.game.possible_answers)
        suggestion = next((w for w in remaining if w in valid_words), suggestion)
    return suggestion

# ── Word prediction from cracked RNG ─────────────────────────────────────────

def predict_word(cracked_rng, cache):
    """Predict the server's next answer word from the cloned RNG.

    Mirrors the server's get_next_word logic. It delegates to predict_leak()
    so the 16-bit-half caching and the index folding exist in one place only.

    Args:
        cracked_rng: the cloned random.Random positioned at the server's state.
        cache: caller-owned list of pending 16-bit halves, shared across calls
            and mutated in place.

    Returns:
        The predicted word.
    """
    # Previously this duplicated predict_leak and printed every raw 32-bit
    # draw (leftover debug output); both now behave identically on the RNG.
    word, _leak = predict_leak(cracked_rng, cache)
    return word

def predict_leak(cracked_rng, cache):
    """Like predict_word, but return (word, leak).

    Each getrandbits(32) draw is split into its high and low 16-bit halves,
    queued in *cache* (mutated in place) and consumed one per call. The low
    12 bits of a half select the word; if that value is >= len(words), it is
    folded with ^ 0x800. Bits 13:12 of the half are the 2-bit leak.
    """
    if len(cache) == 0:
        hi, lo = divmod(cracked_rng.getrandbits(32), 1 << 16)
        cache.extend((hi, lo))
    half = cache.pop(0)
    upper, word_pos = divmod(half, 0x1000)
    if word_pos >= len(words):
        word_pos ^= 0x800
    return words[word_pos], upper & 3

# ── One Wordle round ──────────────────────────────────────────────────────────

def play_round(io, round_num):
    """
    Play one Wordle round of at most 6 guesses.

    Assumes the LAYER <round_num> line has already been consumed. After the
    last guess, reads until LAYER <round_num+1>, capturing any LEAK printed
    in between.

    Returns (answer_word, leak): answer_word is the winning guess, or None if
    all guesses failed; leak is whatever goto_layer() parsed (may be None).
    """
    max_guesses = 6
    ws.new_game("")
    answer_word = None

    for _ in range(max_guesses):
        guess = pick_guess()
        io.sendline(guess.encode())
        _raw, pattern = recv_pattern(io)
        ws.play(guess, to_solver_feedback(pattern))
        print(f"  {guess} -> {pattern}", end="")
        # Judge the win on the parsed pattern (the same value fed to the
        # solver) rather than the exact raw bytes, so a "\r\n" terminator or a
        # prefix on the line cannot make us keep guessing into the next round.
        if pattern == "GGGGG":
            answer_word = guess
            break

    if answer_word is not None:
        print("  ✓")
        answer_word_idx = word_to_idx[answer_word]
        print(f" Answer Word: {answer_word}, The answer Words index: {answer_word_idx}, Sanity check: {words[answer_word_idx]}")
    else:
        print("  ✗")

    leak = goto_layer(io, round_num + 1)
    return answer_word, leak

# ── Main ──────────────────────────────────────────────────────────────────────

io = remote("chall.polygl0ts.ch", 6067)
goto_layer(io, 1)  # drain the greeting; returns once "LAYER 0001" has been read

# ── Phase 1: collect observations ────────────────────────────────────────────
print(f"\n[*] Phase 1: collecting {OBS_ROUNDS} rounds of observations")
observations = []  # entry r-1 is round r's (word_idx, leak), or None if unusable

for rnd in range(1, OBS_ROUNDS + 1):
    sys.stdout.write(f"\r  round {rnd}/{OBS_ROUNDS}  ")
    sys.stdout.flush()
    answer_word, leak = play_round(io, rnd)
    # A round only yields bits when we both won it and parsed its LEAK.
    if answer_word is None or leak is None:
        observations.append(None)
        continue
    observations.append((word_to_idx[answer_word], leak))

good = len(observations) - observations.count(None)
print(f"\n[*] {good}/{OBS_ROUNDS} rounds won  ({OBS_ROUNDS - good} failures → those pairs use all-? bits)")

# ── Phase 2: crack the RNG ────────────────────────────────────────────────────
import copy
from z3 import sat, Not, And

total_obs_rounds = OBS_ROUNDS
# One constraint per getrandbits(32) call, i.e. per (high, low) round pair.
n_constraints = OBS_ROUNDS // 2
print(f"[*] Phase 2: building {n_constraints} constraints and running z3")
ut = Untwister()
for i in range(0, OBS_ROUNDS - 1, 2):
    ut.submit(build_constraint(observations[i], observations[i + 1]))

print("[*] Solving...")
cracked, _ = ut.get_random()
print("[*] RNG cracked!")

# Count how many MT states satisfy the constraints (capped at MAX_SOLUTIONS).
# The recovered state is solution #1. Block it explicitly before searching;
# otherwise the first check() may simply rediscover it and overcount by one.
# NOTE(review): assumes cracked's MT words are the model values of ut.MT (as in
# upstream symbolic_mersenne_cracker's get_random). If not, this block is a
# no-op and the count is at worst one too high, as before.
MAX_SOLUTIONS = 20
print("[*] Counting solutions...")
cracked_mt = cracked.getstate()[1][:624]
ut.solver.add(Not(And([ut.MT[j] == cracked_mt[j] for j in range(624)])))
n_solutions = 1
while n_solutions < MAX_SOLUTIONS:
    if ut.solver.check() != sat:
        break
    model = ut.solver.model()
    # model_completion=True yields a concrete value even for unassigned vars.
    ut.solver.add(Not(And([ut.MT[j] == model.eval(ut.MT[j], model_completion=True)
                           for j in range(624)])))
    n_solutions += 1
    print(f"Found solution: {n_solutions}")
print(f"[*] Found {n_solutions}{'+ ' if n_solutions == MAX_SOLUTIONS else ' '} solution(s)")

# ── Verify crack against the observed getrandbits(32) calls ──────────────────
# Phase 3 uses `cracked` as-is for round OBS_ROUNDS + 1, so it sits right after
# the last observed call. Its stored index counts the outputs already drawn
# from its current MT block. Resetting the index to 0 therefore replays the
# most recent `n_replay` observed calls, NOT necessarily call #0: with 1248
# observed calls and index 624 it replays calls #624..#1247.
total_calls = total_obs_rounds // 2
raw_st = list(cracked.getstate()[1])   # 625 ints: MT[0..623] + index
n_replay = min(raw_st[624], total_calls)
first_call = total_calls - n_replay
raw_st[624] = 0                         # reset index → next call = #first_call
verify_rng = copy.deepcopy(cracked)
verify_rng.setstate((3, tuple(raw_st), None))
print(f"[*] Verifying cracked state against Phase-1 observations "
      f"(calls #{first_call}..#{total_calls - 1})...")
ok = 0
total = 0
for call_i in range(first_call, total_calls):
    val      = verify_rng.getrandbits(32)
    high_raw = (val >> 16) & 0xFFFF
    low_raw  = val & 0xFFFF
    # Round 2k+1 used the high half of call k, round 2k+2 the low half.
    for half, raw, obs in (("hi", high_raw, observations[2 * call_i]),
                           ("lo", low_raw,  observations[2 * call_i + 1])):
        if obs is None:
            continue
        total += 1
        idx = raw & 0xFFF
        if idx >= len(words):
            idx ^= 0x800
        leak = (raw >> 12) & 3
        if idx == obs[0] and leak == obs[1]:
            ok += 1
        else:
            print(f"  call #{call_i} {half}: predicted {words[idx]}(leak={leak}) "
                  f"actual {words[obs[0]]}(leak={obs[1]}) ✗")
print(f"  {ok}/{total} match  {'✓ crack looks good' if ok == total else '✗ WRONG STATE — constraints are off'}")

# ── Phase 3: 20 consecutive first-guess wins ──────────────────────────────────
print(f"\n[*] Phase 3: predicting and winning {WIN_STREAK} rounds in a row (starting round {total_obs_rounds + 1})")
pred_cache = []  # pending 16-bit halves; also consumed by the enigma prediction
first_round = total_obs_rounds + 1
last_round = total_obs_rounds + WIN_STREAK

for rnd in range(first_round, last_round + 1):
    predicted = predict_word(cracked, pred_cache)
    print(f"  [{rnd}] predicting {predicted}")

    io.sendline(predicted.encode())
    answer, pattern = recv_pattern(io)
    print(f"    -> {pattern}")

    # Judge the win on the parsed pattern, consistent with what is printed;
    # this tolerates a "\r\n" terminator or a prefix on the server line.
    if pattern != "GGGGG":
        print("  [!] wrong prediction — aborting")
        io.interactive()
        sys.exit(1)

    # Consume the next LAYER header, except after the final win, when the
    # server moves on to the enigma instead.
    if rnd < last_round:
        goto_layer(io, rnd + 1)

# ── Final enigma ──────────────────────────────────────────────────────────────
# Per the server flow described here: after the winning streak the server calls
# get_next_word one more time and asks:
#   "Answer the last enigma: "
# The answer is "LAIN"[leak], where leak comes from that extra get_next_word call.
# predict_leak consumes the next 16-bit half from pred_cache. With an even
# WIN_STREAK (20 rounds = 10 getrandbits calls) the cache is empty, so a fresh
# draw is made and its high half is used. With an odd streak, the leftover low
# half would be used instead.
_, enigma_leak = predict_leak(cracked, pred_cache)
enigma_answer  = "LAIN"[enigma_leak]
print(f"\n[*] Enigma answer: '{enigma_answer}'")

# Wait for the prompt before answering, then hand the session to the user.
io.recvuntil(b"Answer the last enigma: ")
io.sendline(enigma_answer.encode())
io.interactive()
