#!/usr/bin/env python3
import itertools
import signal
from time import sleep
from pwn import *

class Exploit:
    """Walk the target binary through its five stages and log each token.

    The stages share state: ``self.elf`` is the currently loaded (possibly
    patched) binary and ``self.p`` is the process currently being driven.
    ``stage1`` .. ``stage5`` are meant to be called in that order.
    """

    def __init__(self, filename: str):
        """Load the target binary and start an initial process from it.

        Args:
            filename: Path to the original, unpatched binary.
        """
        self.filename = filename
        self.patched = "arcadian_patched"
        self.elf = ELF(self.filename)
        self.p = process(self.elf.path)

    def _stop_process(self) -> None:
        """Close the process held in ``self.p``, if there is one."""
        current = getattr(self, "p", None)
        if current is not None:
            current.close()

    def _patch_and_reload(self, addr: int, data: bytes) -> None:
        """Write ``data`` at ``addr``, save to ``self.patched`` and reload it."""
        self.elf.write(addr, data)
        self.elf.save(self.patched)
        self.elf = ELF(self.patched)

    @staticmethod
    def _xor_rolling_key(cipher: list, base_key: int) -> str:
        """XOR each value with a key that starts at ``base_key`` and grows by one.

        Args:
            cipher: Integer values to decrypt, in order.
            base_key: Key applied to the first value.

        Returns:
            The decrypted characters joined into a string.
        """
        return "".join(
            chr(value ^ (base_key + index)) for index, value in enumerate(cipher)
        )

    def get_flag(self, token: bytes) -> None:
        """Read the process output past ``token`` and log what follows up to ``}``.

        Args:
            token: Marker that precedes the flag in the output, e.g. b"TOKEN1: ".
        """
        self.p.recvuntil(token)
        flag = self.p.recvuntil(b"}").decode('utf-8')
        log.success(f"{token.decode('utf-8')} = {flag}")

    def stage1(self) -> None:
        """Patch out the first check, run the patched binary and log TOKEN1."""
        addr = 0x9d28                        # patch off the first check to get first flag
        # The process started in __init__ is replaced below; close it so it
        # does not keep running alongside the patched one.
        self._stop_process()
        self._patch_and_reload(addr, b"\x30\xc0")    # xor al,al to pass the test
        self.p = process(self.elf.path)
        self.p.sendline(b"lol lol lol lol")
        self.get_flag(b"TOKEN1: ")

    def stage2(self) -> None:
        """Send the six-move solution, log TOKEN2 and attach gdb for stage3."""
        # I brute forced this solution with another program
        stage2_solution = b"L R U D L X"
        self.p.recvuntil(b"Enter your 6 moves: >")
        self.p.sendline(stage2_solution)
        self.get_flag(b"TOKEN2: ")
        # Need this for stage3
        gdb.attach(self.p, gdbscript='continue')

    def stage3(self) -> None:
        """Wait out the 5 second countdown, send SIGTRAP and log TOKEN3.

        Relies on the gdb session attached in ``stage2``; the operator has to
        resume the process from gdb by hand (see the log message).
        """
        self.p.recvuntil(b"Decrypting TOKEN3 in 5 seconds")
        sleep(5)
        self.p.send_signal(signal.SIGTRAP)
        log.info("Hit continue on gdb twice!")
        self.p.sendline(b"dump")
        self.get_flag(b"TOKEN3: ")

    def stage4(self) -> None:
        """Zero the STAGE4_LOCK byte in the patched binary and log TOKEN4."""
        self.elf = ELF(self.patched)
        stage4_lock_offset = self.elf.sym['STAGE4_LOCK']    # Be 31337 and just use pwnt00ls
        # The previous process may still be executing self.patched. Close it
        # before rewriting that file: Linux can refuse writes to a running
        # executable (ETXTBSY), and the old process is no longer needed.
        self._stop_process()
        self._patch_and_reload(stage4_lock_offset, b"\x00")
        self.p = process([self.elf.path, '--stage', '4'])
        self.p.send(b'\n')
        self.get_flag(b"TOKEN4: ")

    def stage5(self) -> None:
        """Decrypt the embedded flag, send it to stage 5 and go interactive."""
        # The most meme stage of them all lmao easiest one
        cipher = [0xF5,0xE5,0xE4,0xEB,0xD2,0xEB,0xF9,0xEF,0x80,0xDC,0x99,0xD5,0x80,0x84,0x81,0xC9]
        base_key = 0xa5
        results = self._xor_rolling_key(cipher, base_key)
        log.success(f"FLAG = {results}")
        # Close the stage4 process before starting the stage5 one.
        self._stop_process()
        self.p = process([self.elf.path, '--stage', '5'])
        self.p.sendline(results.encode('utf-8'))
        self.p.interactive()

if __name__ == "__main__":
    # Run the stages strictly in order: each one builds on the process and
    # the patched binary left behind by the one before it.
    exploit = Exploit("./arcadian")
    for run_stage in (
        exploit.stage1,
        exploit.stage2,
        exploit.stage3,
        exploit.stage4,
        exploit.stage5,
    ):
        run_stage()
