# Rekall Memory Forensics
#
# Based on original code by:
# Copyright (C) 2007-2013 Volatility Foundation
#
# Authors:
# Michael Hale Ligh <michael.ligh@mnin.org>
#
# This code:
# Copyright 2014 Google Inc. All Rights Reserved.
#
# Authors:
# Michael Cohen <scudette@google.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or (at
# your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#

import base64
import struct

import distorm3

from rekall import addrspace
from rekall import plugin
from rekall import testlib

from rekall.plugins.windows import common
from rekall.plugins.overlays.windows import pe_vtypes

# Index of the RIP register in distorm3's register table. HookHeuristic stores
# the address of the next instruction in its register map under this key, so
# memory operands whose register is RIP (RIP-relative addressing) can be
# resolved by the emulation.
RIP_INDEX = distorm3.Registers.index("RIP")


class DecodingError(Exception):
    """Signals that an instruction (or operand) could not be emulated.

    The exception message is normally the text of the offending instruction.
    """


class HookHeuristic(object):
    """A Hook heuristic detects possible hooks.

    This heuristic emulates some common CPU instructions to try and detect
    control flow jumps within the first few instructions of a function.

    These are essentially guesses based on the most common hook types. Be aware
    that these are pretty easy to defeat which will cause the hook to be missed.

    See rekall/src/hooks/amd64.asm and rekall/src/hooks/i386.asm for the test
    cases which illustrate the type of hooks that we will detect.

    The emulated state is deliberately simple: registers and memory are plain
    dicts keyed by register index / address, and the stack is a list used only
    by PUSH, POP and RET. Arithmetic results are not truncated to the operand
    width.
    """

    # struct formats used to read each supported memory operand width (in
    # bits) as an unsigned integer. x86 and amd64, the architectures emulated
    # here, are little-endian; an explicit byte order keeps the result
    # independent of the host running the analysis.
    _UNPACK_FORMATS = {8: "<B", 16: "<H", 32: "<I", 64: "<Q"}

    def __init__(self, session=None):
        self.session = session
        self.Reset()

    def Reset(self):
        """Clears the emulated registers, stack and memory."""
        # Keep track of registers, stack and main memory.
        self.regs = {}
        self.stack = []
        self.memory = {}

    def _EffectiveAddress(self, operand):
        """Returns the address referenced by a memory operand.

        Registers whose value is unknown are treated as 0.

        Returns:
          The address, or None if the operand is not a memory reference.
        """
        if operand.type == "AbsoluteMemory":
            # Register reference e.g. [EBX + 0x10].
            return self.regs.get(operand.index, 0) + operand.disp

        if operand.type == "AbsoluteMemoryAddress":
            # Memory reference e.g. [0x100].
            return operand.disp

        return None

    def WriteToOperand(self, instruction, operand, value):
        """Stores value into a register or an emulated memory location.

        Memory writes only update the local memory map; the underlying address
        space is never modified.

        Raises:
          DecodingError: if the operand is neither a register nor a memory
            reference.
        """
        if operand.type == "Register":
            self.regs[operand.index] = value
            return

        address = self._EffectiveAddress(operand)
        if address is None:
            raise DecodingError(str(instruction))

        self.memory[address] = value

    def ReadFromOperand(self, operand):
        """Read the operand.

        We support the following forms:

        - Immediate:  JMP 0x123456
        - AbsoluteMemoryAddress: JMP [0x123456]
        - AbsoluteMemory: JMP [EAX + 0x10]
        - Register: JMP EAX

        Returns:
          The operand's value, or None for operand types which are not
          emulated.
        """
        if operand.type == "Register":
            return self.regs.get(operand.index, 0)

        if operand.type == "Immediate":
            return operand.value

        address = self._EffectiveAddress(operand)
        if address is not None:
            return self._GetMemoryAddress(address, operand.size)

        # Any other operand type is not emulated.
        return None

    def _GetMemoryAddress(self, offset, size):
        """Reads an unsigned integer of `size` bits stored at `offset`.

        Values written earlier by the emulation take precedence over the
        contents of the underlying address space.

        Raises:
          DecodingError: if size is not 8, 16, 32 or 64 bits.
        """
        try:
            # First check our local cache for a previously written value.
            return self.memory[offset]
        except KeyError:
            pass

        try:
            format_string = self._UNPACK_FORMATS[size]
        except KeyError:
            raise DecodingError(
                "Unsupported memory operand size: %r" % (size,))

        # Integer division: size / 8 would be a float under Python 3.
        data = self.address_space.read(offset, size // 8)
        return struct.unpack(format_string, data)[0]

    def ProcessLEA(self, instruction):
        """Copies the address from the second operand to the first.

        Raises:
          DecodingError: if the second operand is not a memory reference.
        """
        address = self._EffectiveAddress(instruction.operands[1])
        if address is None:
            # LEA's source must be a memory reference; anything else cannot
            # be emulated.
            raise DecodingError(str(instruction))

        self.WriteToOperand(instruction, instruction.operands[0], address)

    def ProcessPUSH(self, instruction):
        """Pushes the operand's value onto the emulated stack."""
        self.stack.append(self.ReadFromOperand(instruction.operands[0]))

    def ProcessPOP(self, instruction):
        """Pops into the operand; an empty emulated stack yields 0."""
        try:
            value = self.stack.pop()
        except IndexError:
            value = 0

        self.WriteToOperand(instruction, instruction.operands[0], value)

    def ProcessRET(self, _):
        """Returns the popped return address, or None if the stack is empty."""
        if self.stack:
            return self.stack.pop()

    def ProcessMOV(self, instruction):
        value = self.ReadFromOperand(instruction.operands[1])

        self.WriteToOperand(instruction, instruction.operands[0], value)

    def ProcessINC(self, instruction):
        value = self.ReadFromOperand(instruction.operands[0])

        self.WriteToOperand(instruction, instruction.operands[0], value + 1)

    def ProcessDEC(self, instruction):
        value = self.ReadFromOperand(instruction.operands[0])

        self.WriteToOperand(instruction, instruction.operands[0], value - 1)

    def ProcessCMP(self, instruction):
        """We dont do anything with the comparison since we dont test for it."""
        _ = instruction

    def ProcessTEST(self, instruction):
        """We dont do anything with the comparison since we dont test for it."""
        _ = instruction

    def _Operate(self, instruction, operator):
        """Emulates `op0 = operator(op0, op1)` for two-operand instructions."""
        value1 = self.ReadFromOperand(instruction.operands[0])
        value2 = self.ReadFromOperand(instruction.operands[1])

        self.WriteToOperand(
            instruction, instruction.operands[0], operator(value1, value2))

    def ProcessXOR(self, instruction):
        return self._Operate(instruction, lambda x, y: x ^ y)

    def ProcessADD(self, instruction):
        return self._Operate(instruction, lambda x, y: x + y)

    def ProcessSUB(self, instruction):
        return self._Operate(instruction, lambda x, y: x - y)

    def ProcessAND(self, instruction):
        return self._Operate(instruction, lambda x, y: x & y)

    def ProcessOR(self, instruction):
        return self._Operate(instruction, lambda x, y: x | y)

    def ProcessSHL(self, instruction):
        return self._Operate(instruction, lambda x, y: x << y)

    def ProcessSHR(self, instruction):
        return self._Operate(instruction, lambda x, y: x >> y)

    def Inspect(self, function, instructions=10):
        """The main entry point to the Hook processor.

        We emulate the function instructions and try to determine the jump
        destination.

        Args:
           function: A basic.Function() instance.
           instructions: Number of instructions to decode (passed to
             function.Decompose()).

        Returns:
           The emulated target of the first RET, CALL or (conditional or
           unconditional) branch encountered. Returns None if no such
           instruction is decoded, if a RET finds the emulated stack empty, or
           if the target operand type is not emulated.
        """
        self.Reset()
        self.address_space = function.obj_vm

        for instruction in function.Decompose(instructions=instructions):
            # RIP always holds the address of the next instruction, so memory
            # operands relative to RIP resolve to the right address.
            self.regs[RIP_INDEX] = instruction.address + instruction.size
            flow_control = instruction.flowControl

            if flow_control == "FC_NONE":
                # Instructions without a Process<MNEMONIC> handler are ignored.
                handler = getattr(
                    self, "Process%s" % instruction.mnemonic, None)
                if handler is not None:
                    handler(instruction)

            elif flow_control == "FC_RET":
                # RET Instruction terminates processing.
                return self.ProcessRET(instruction)

            elif flow_control in ("FC_CALL", "FC_UNC_BRANCH", "FC_CND_BRANCH"):
                # CALL or JMP: the target is the hook destination. Flags are
                # not emulated, so conditional branches are treated as taken.
                return self.ReadFromOperand(instruction.operands[0])


class CheckPEHooks(plugin.VerbosityMixIn, common.WindowsCommandPlugin):
    """Checks a pe file mapped into memory for hooks."""

    name = "check_pehooks"

    @classmethod
    def args(cls, parser):
        super(CheckPEHooks, cls).args(parser)
        parser.add_argument("--image-base", default=0,
                            help="The base address of the pe image in memory.")

        parser.add_argument(
            "--type", default="all", choices=["all", "iat", "inline", "eat"],
            help="Type of hook to display.")

    def __init__(self, image_base=0, type="all", **kwargs):
        """Creates the checker.

        Args:
          image_base: The base of the PE image. It is passed through the
            address resolver's get_address_by_name(); presumably a symbolic
            name is accepted as well as a plain address - confirm.
          type: Which hooks render() displays: "all", "iat", "inline" or
            "eat". The name shadows the builtin but is kept for compatibility
            with the --type argument.
        """
        super(CheckPEHooks, self).__init__(**kwargs)
        self.image_base = self.session.address_resolver.get_address_by_name(
            image_base)
        self.heuristic = HookHeuristic(session=self.session)
        self.hook_type = type

    def detect_IAT_hooks(self):
        """Detect Import Address Table hooks.

        An IAT hook is where malware changes the IAT entry for a dll after its
        loaded so that when it is called from within the DLL, flow control is
        directed to the malware instead.

        We determine the IAT entry is hooked if the address is outside the dll
        which is imported.

        Yields:
          (function_name, func_address) tuples, where function_name has the
          form "dll:function".
        """
        pe = pe_vtypes.PE(image_base=self.image_base, session=self.session)

        # First try to find all the names of the imported functions. IAT
        # entries are paired with these names by position below.
        imports = [
            (dll, func_name) for dll, func_name, _ in pe.ImportDirectory()]

        resolver = self.session.address_resolver

        for idx, (_, func_address, _) in enumerate(pe.IAT()):
            target_dll, target_func_name = imports[idx]
            target_dll = resolver.NormalizeModuleName(target_dll)

            self.session.report_progress(
                "Checking function %s!%s", target_dll, target_func_name)

            # We only want the containing module.
            _, _, name = resolver.FindContainingModule(func_address)
            if target_dll == name:
                continue

            yield "%s:%s" % (target_dll, target_func_name), func_address

    def _render_address_hooks(self, renderer, title, column_name, hooks):
        """Renders (function_name, func_address) hook tuples as a table.

        Args:
          renderer: The renderer to write the table to.
          title: Heading of the first column (e.g. "Import").
          column_name: Internal name of the first column.
          hooks: An iterable of (function_name, func_address) tuples.
        """
        renderer.table_header([
            (title, column_name, "[wrap:60]"),
            ("Dest", "destination", "[addrpad]"),
            ("Dest Name", "dest_name", "[wrap:60]"),
            ])

        resolver = self.session.address_resolver
        for function_name, func_address in hooks:
            destination = resolver.format_address(
                func_address, max_distance=2**64)
            if not destination:
                # Unresolvable destinations are shown as a raw address.
                destination = "%#x" % func_address

            renderer.table_row(
                function_name, func_address, destination)

    def render_iat_hooks(self, renderer):
        """Renders the table of detected IAT hooks."""
        self._render_address_hooks(
            renderer, "Import", "import", self.detect_IAT_hooks())

    def detect_EAT_hooks(self, size=0):
        """Detect Export Address Table hooks.

        An EAT hook is where malware changes the EAT entry for a dll after its
        loaded so that a new DLL wants to link against it, the new DLL will use
        the malware's function instead of the exporting DLL's function.

        We determine the EAT entry is hooked if the address lies outside the
        exporting dll.

        Args:
          size: The size of the image in bytes. If 0, the extent is taken from
            the end of the last section in the PE header.

        Yields:
          (function_name, func) tuples, where function_name has the form
          "dll:name (hint)".
        """
        address_space = self.session.GetParameter("default_address_space")
        pe = pe_vtypes.PE(image_base=self.image_base, session=self.session,
                          address_space=address_space)
        start = self.image_base
        end = self.image_base + size

        # If the dll size is not provided we parse it from the PE header. All
        # sections are considered, not only executable ones.
        if not size:
            for _, _, virtual_address, section_size in pe.Sections():
                section_end = self.image_base + virtual_address + section_size
                if section_end > end:
                    end = section_end

        resolver = self.session.address_resolver

        for dll, func, name, hint in pe.ExportDirectory():
            self.session.report_progress("Checking export %s!%s", dll, name)
            func_address = func.v()

            # Skip zero or invalid addresses.
            if address_space.read(func_address, 10) == "\x00" * 10:
                continue

            # Exports pointing inside the exporting image are not hooked.
            if start <= func_address < end:
                continue

            function_name = "%s:%s (%s)" % (
                resolver.NormalizeModuleName(dll), name, hint)

            yield function_name, func

    def render_eat_hooks(self, renderer):
        """Renders the table of detected EAT hooks."""
        self._render_address_hooks(
            renderer, "Export", "export", self.detect_EAT_hooks())

    def detect_inline_hooks(self):
        """A Generator of hooked exported functions from this PE file.

        Each export is emulated for its first 3 instructions by the hook
        heuristic; functions where no destination is found are skipped.

        Yields:
          A tuple of (function, name, jump_destination)
        """
        # Inspect the export directory for inline hooks.
        pe = pe_vtypes.PE(image_base=self.image_base, session=self.session)

        for _, function, name, _ in pe.ExportDirectory():
            self.session.report_progress(
                "Checking function %#x (%s)", function, name)

            # Try to detect an inline hook.
            destination = self.heuristic.Inspect(function, 3) or ""

            # If we did not detect a hook we skip this function.
            if destination:
                yield function, name, destination

    def render_inline_hooks(self, renderer):
        """Renders exported functions whose first instructions jump away.

        Destinations that resolve to a known name are only shown when verbosity
        is at least 10; unresolved destinations are reported as hooks.
        """
        renderer.table_header([("Name", "name", "20s"),
                               ("Hook", "hook", "30s"),
                               ("Disassembly", "location", "60s"),
                               ])
        for function, name, destination in self.detect_inline_hooks():
            hook_detected = False

            # Try to resolve the destination into a name.
            destination_name = self.session.address_resolver.format_address(
                destination, max_distance=2**64)

            # We know about it. We suppress the output for jumps that go into a
            # known module. These should be visible using the regular vad
            # module.
            if destination_name:
                destination = destination_name
            else:
                destination = "%#x" % destination
                hook_detected = True

            # Skip non hooked results if verbosity is too low.
            if self.verbosity < 10 and not hook_detected:
                continue

            # Only highlight hooks when verbosity is above 1.
            highlight = ""
            if hook_detected and self.verbosity > 1:
                highlight = "important"

            renderer.table_row(name, destination, function.deref(),
                               highlight=highlight)

    def render(self, renderer):
        """Renders the hook tables selected by the --type argument."""
        if self.hook_type in ["all", "inline"]:
            self.render_inline_hooks(renderer)

        if self.hook_type in ["all", "iat"]:
            self.render_iat_hooks(renderer)

        if self.hook_type in ["all", "eat"]:
            self.render_eat_hooks(renderer)


class IATHooks(plugin.VerbosityMixIn, common.WinProcessFilter):
    """Detect IAT hooks in process and kernel memory"""

    name = "hooks_iat"

    def render_iat_hooks(self, task, dll, renderer):
        """Emits one table row per hooked IAT entry of a single module.

        Args:
          task: The process being examined (its pid and name are reported).
          dll: A loaded module of the task; the PE image at dll.base is checked.
          renderer: The renderer receiving the rows.
        """
        checker = self.session.plugins.check_pehooks(image_base=dll.base)
        resolver = self.session.address_resolver

        for function_name, func_address in checker.detect_IAT_hooks():
            # Fall back to the raw address when no name can be resolved.
            destination = (
                resolver.format_address(func_address, max_distance=2**64) or
                "%#x" % func_address)

            renderer.table_row(
                task.pid, task.name, dll.name, function_name,
                func_address, destination)

    def render(self, renderer):
        """Checks every loaded module of every selected process."""
        cc = self.session.plugins.cc()
        columns = [
            ("Pid", "pid", "4"),
            ("Proc", "proc", "16"),
            ("Dll", "dll", "30"),
            ("Import", "import", "[wrap:60]"),
            ("Dest", "destination", "[addrpad]"),
            ("Dest Name", "dest_name", "[wrap:60]"),
            ]
        renderer.table_header(columns)

        with cc:
            for task in self.filter_processes():
                # Switch to the task's process context before walking its
                # modules.
                cc.SwitchProcessContext(task)

                for module in task.get_load_modules():
                    self.render_iat_hooks(task, module, renderer)


class TestIATHooks(testlib.SimpleTestCase):
    """Runs the hooks_iat plugin against a single process."""

    PLUGIN = "hooks_iat"

    PARAMETERS = {
        "commandline": "hooks_iat --pid %(pid)s",
    }


class EATHooks(plugin.VerbosityMixIn, common.WinProcessFilter):
    """Detect EAT hooks in process and kernel memory"""

    name = "hooks_eat"

    def render_eat_hooks(self, task, dll, renderer):
        """Emits one table row per hooked export of a single module.

        Args:
          task: The process being examined (its pid and name are reported).
          dll: A loaded module of the task; the PE image at dll.base is checked.
          renderer: The renderer receiving the rows.
        """
        checker = self.session.plugins.check_pehooks(image_base=dll.base)
        resolver = self.session.address_resolver

        for export_name, export_address in checker.detect_EAT_hooks():
            # Fall back to the raw address when no name can be resolved.
            target = (
                resolver.format_address(export_address, max_distance=2**64) or
                "%#x" % export_address)

            renderer.table_row(
                task.pid, task.name, dll.name, export_name,
                export_address, target)

    def render(self, renderer):
        """Checks every loaded module of every selected process."""
        cc = self.session.plugins.cc()
        columns = [
            ("Pid", "pid", "4"),
            ("Proc", "proc", "16"),
            ("Dll", "dll", "30"),
            ("Export", "export", "[wrap:60]"),
            ("Dest", "destination", "[addrpad]"),
            ("Dest Name", "dest_name", "[wrap:60]"),
            ]
        renderer.table_header(columns)

        with cc:
            for task in self.filter_processes():
                # Switch to the task's process context before walking its
                # modules.
                cc.SwitchProcessContext(task)

                for module in task.get_load_modules():
                    self.render_eat_hooks(task, module, renderer)


class TestEATHooks(testlib.SimpleTestCase):
    """Runs the hooks_eat plugin against a single process."""

    PLUGIN = "hooks_eat"

    PARAMETERS = {
        "commandline": "hooks_eat --pid %(pid)s",
    }



class InlineHooks(plugin.VerbosityMixIn, common.WinProcessFilter):
    """Detect API hooks in process and kernel memory"""

    name = "hooks_inline"

    def render_inline_hooks(self, task, dll, renderer):
        """Emits one row per exported function of dll that jumps elsewhere.

        Destinations resolving to a known name are treated as benign and are
        only shown when verbosity is at least 10. Unresolved destinations are
        reported as hooks, highlighted when verbosity is above 1.
        """
        checker = self.session.plugins.check_pehooks(image_base=dll.base)
        resolver = self.session.address_resolver

        for function, name, target in checker.detect_inline_hooks():
            # Jumps into a known module are suppressed by default; these should
            # be visible using the regular vad module.
            resolved = resolver.format_address(target, max_distance=2**64)
            is_hook = not resolved
            if is_hook:
                shown_target = "%#x" % target
            else:
                shown_target = resolved

            if self.verbosity < 10 and not is_hook:
                continue

            if is_hook and self.verbosity > 1:
                highlight = "important"
            else:
                highlight = ""

            renderer.table_row(
                task.pid, task.name, dll.name, name, shown_target,
                function.deref(), highlight=highlight)

    def render(self, renderer):
        """Checks every loaded module of every selected process."""
        cc = self.session.plugins.cc()
        columns = [
            ("Pid", "pid", "4"),
            ("Proc", "proc", "16"),
            ("Dll", "dll", "16"),
            ("Name", "name", "[wrap:20]"),
            ("Hook", "hook", "20"),
            ("Disassembly", "location", "60"),
            ]
        renderer.table_header(columns)

        with cc:
            for task in self.filter_processes():
                # Switch to the task's process context before walking its
                # modules.
                cc.SwitchProcessContext(task)

                for module in task.get_load_modules():
                    self.render_inline_hooks(task, module, renderer)


class TestInlineHooks(testlib.SimpleTestCase):
    """Runs the hooks_inline plugin against a single process."""

    PLUGIN = "hooks_inline"

    PARAMETERS = {
        "commandline": "hooks_inline --pid %(pid)s",
    }


class TestHookHeuristics(testlib.RekallBaseUnitTestCase):
    """Test the hook detection heuristic.

    The actual test cases are generated using the nasm assembler in:

    rekall/src/hooks/amd64.asm and rekall/src/hooks/i386.asm
    """
    PLUGIN = "apihooks"

    def testHook(self):
        session = self.MakeUserSession()

        # The target address should be fixed at this offset.
        target = 0x100

        heuristic = HookHeuristic(session=session)

        profile = session.LoadProfile("tests/hooks")
        arch = session.profile.metadata("arch")
        for test_case in profile.data[arch]:
            offset = test_case["offset"]
            # Test case data is the base64 encoded assembly snippet mapped at
            # the specified offset in the address space. base64.b64decode works
            # on both Python 2 and 3, unlike str.decode("base64").
            address_space = addrspace.BufferAddressSpace(
                data=base64.b64decode(test_case["data"]),
                session=session, base_offset=offset)

            function = session.profile.Function(
                offset=offset, vm=address_space, name=test_case["name"])

            # Detect the jump in this function
            destination = heuristic.Inspect(function)

            # All hooks in test cases go to the same target offset (0x100).
            # Name the failing case so a regression is easy to locate.
            self.assertEqual(
                destination, target,
                "Test case %s: expected jump to %#x, got %r" % (
                    test_case["name"], target, destination))
