from __future__ import print_function, absolute_import
import re
from llvmlite.llvmpy.core import (Type, Builder, LINKAGE_INTERNAL,
                       Constant, ICMP_EQ)
import llvmlite.llvmpy.core as lc
import llvmlite.binding as ll

from numba import typing, types, cgutils
from numba.utils import cached_property
from numba.targets.base import BaseContext
from numba.targets.callconv import MinimalCallConv
from numba.targets import cmathimpl, operatorimpl
from numba.typing import cmathdecl, operatordecl
from numba.funcdesc import transform_arg_name
from .cudadrv import nvvm
from . import codegen, nvvmutils


# -----------------------------------------------------------------------------
# Typing


class CUDATypingContext(typing.BaseContext):
    """Typing context for the CUDA target."""

    def init(self):
        """Install the typing registries used by the CUDA target.

        The CUDA-specific declaration modules are imported when this method
        runs rather than at module import time.
        """
        from . import cudadecl, cudamath

        # Registries are installed one after another, in this order.
        for declmod in (cudadecl, cudamath, cmathdecl, operatordecl):
            self.install(declmod.registry)

# -----------------------------------------------------------------------------
# Implementation

# Matches every character that is NOT an ASCII letter or digit, i.e. the
# characters that ``CUDATargetContext.mangle_name`` replaces with an escape.
# Both letter cases are spelled out instead of relying on ``re.I``: on
# Python 3 a case-insensitive ``[a-z]`` also accepts four non-ASCII letters
# (U+0130, U+0131, U+017F and U+212A), which would otherwise be left
# unescaped in mangled names.
VALID_CHARS = re.compile(r'[^a-zA-Z0-9]')


class CUDATargetContext(BaseContext):
    """Code generation context for the CUDA target."""

    implement_powi_as_math_call = True
    strict_alignment = True

    # Overrides
    def create_module(self, name):
        """Return an empty module called *name* made by the CUDA codegen."""
        cuda_codegen = self._internal_codegen
        return cuda_codegen._create_empty_module(name)

    def init(self):
        """Create the codegen object, install the implementation registries
        and build the target data from the NVVM default data layout.
        """
        from . import cudaimpl, printimpl, libdevice

        self._internal_codegen = codegen.JITCUDACodegen("numba.cuda.jit")

        # Registries are installed one after another, in this order.
        for implmod in (cudaimpl, printimpl, libdevice,
                        cmathimpl, operatorimpl):
            self.install_registry(implmod.registry)

        data_layout = nvvm.default_data_layout
        self._target_data = ll.create_target_data(data_layout)

    def jit_codegen(self):
        """Return the codegen object created in ``init``."""
        return self._internal_codegen

    @property
    def target_data(self):
        """The target data object created in ``init``."""
        return self._target_data

    @cached_property
    def call_conv(self):
        """The calling convention object bound to this context."""
        convention = CUDACallConv(self)
        return convention

    @classmethod
    def mangle_name(cls, name):
        """
        Mangle the given string

        Every character matched by ``VALID_CHARS`` is replaced by its code
        point in uppercase hexadecimal between two underscores
        (e.g. ``.`` becomes ``_2E_``).
        """
        return VALID_CHARS.sub(lambda m: "_%X_" % ord(m.group(0)), name)

    def mangler(self, name, argtypes):
        """Return the mangled form of *name* qualified by its argument types.

        The qualified name is *name*, a dot, and the dot-joined
        ``transform_arg_name`` of every entry in *argtypes*.
        """
        type_suffix = '.'.join(map(transform_arg_name, argtypes))
        return self.mangle_name(name + '.' + type_suffix)

    def prepare_cuda_kernel(self, func, argtypes):
        """Adapt *func* to CUDA LLVM and return its kernel wrapper.

        A wrapper is generated for *func*, *func* itself is given internal
        linkage, and the data layout of its module is fixed up.
        """
        kernel_module = func.module
        kernel_wrapper = self.generate_kernel_wrapper(func, argtypes)
        func.linkage = LINKAGE_INTERNAL
        nvvm.fix_data_layout(kernel_module)
        return kernel_wrapper

    def generate_kernel_wrapper(self, func, argtypes):
        """Build the ``cudaPy_``-prefixed kernel entry point wrapping *func*.

        The wrapper is emitted into a separate module named
        "cuda.kernel.wrapper".  It unpacks its arguments, calls *func*
        (redeclared by name in that module) and returns void.  When the call
        status is neither OK nor a Python exception, the status code is
        recorded in an error-code global, and if that recording succeeded the
        tid/ctaid special-register values are stored in further globals.
        The wrapper module is then linked into the module of *func*, which is
        verified, and the wrapper function found in that module is returned.
        """
        target_module = func.module

        packer = self.get_arg_packer(argtypes)
        llargtys = list(packer.argument_types)
        wrapper_fnty = Type.function(Type.void(), llargtys)
        wrapper_module = self.create_module("cuda.kernel.wrapper")
        retptr_ty = self.call_conv.get_return_type(types.pyobject)
        callee_fnty = Type.function(Type.int(), [retptr_ty] + llargtys)
        # Redeclare the wrapped function inside the wrapper module.
        callee = wrapper_module.add_function(callee_fnty, name=func.name)
        wrapfn = wrapper_module.add_function(wrapper_fnty,
                                             name="cudaPy_" + callee.name)
        builder = Builder.new(wrapfn.append_basic_block(''))

        # Define error handling variables
        def define_error_gv(postfix):
            # Null-initialised int global named after the wrapper function.
            errvar = wrapper_module.add_global_variable(
                Type.int(), name=wrapfn.name + postfix)
            errvar.initializer = Constant.null(errvar.type.pointee)
            return errvar

        gv_exc = define_error_gv("__errcode__")
        # For each dimension the tid global is defined before the ctaid one.
        per_dim = [(define_error_gv("__tid%s__" % dim),
                    define_error_gv("__ctaid%s__" % dim))
                   for dim in 'xyz']
        gv_tid = [tid_gv for tid_gv, _ in per_dim]
        gv_ctaid = [ctaid_gv for _, ctaid_gv in per_dim]

        callargs = packer.from_arguments(builder, wrapfn.args)
        status, _ = self.call_conv.call_function(
            builder, callee, types.void, argtypes, callargs)

        # Check error status: on success the kernel simply returns.
        with cgutils.if_likely(builder, status.is_ok):
            builder.ret_void()

        not_python_exc = builder.not_(status.is_python_exc)
        with builder.if_then(not_python_exc):
            # User exception raised
            no_error = Constant.null(gv_exc.type.pointee)

            # Use atomic cmpxchg to prevent rewriting the error status
            # Only the first error is recorded
            # NOTE(review): the function is only declared here, without a
            # body; presumably it is resolved to a real cmpxchg later in the
            # pipeline -- confirm.
            cas_fnty = Type.function(
                no_error.type,
                [gv_exc.type, no_error.type, no_error.type])
            cas_fn = wrapper_module.add_function(cas_fnty,
                                                 name="___numba_cas_hack")
            previous = builder.call(cas_fn,
                                    [gv_exc, no_error, status.code])
            changed = builder.icmp(ICMP_EQ, previous, no_error)

            # If the xchange is successful, save the thread ID.
            sreg = nvvmutils.SRegBuilder(builder)
            with builder.if_then(changed):
                # All tid values are stored first, then all ctaid values.
                for read_sreg, slots in ((sreg.tid, gv_tid),
                                         (sreg.ctaid, gv_ctaid)):
                    for dim, slot in zip("xyz", slots):
                        builder.store(read_sreg(dim), slot)

        builder.ret_void()
        nvvm.set_cuda_kernel(wrapfn)

        # Link the wrapper module, via its textual form, into the module of
        # the wrapped function.
        wrapper_ir = str(wrapper_module)
        target_module.link_in(ll.parse_assembly(wrapper_ir))
        target_module.verify()

        # Hand back the wrapper as it exists in the linked module.
        return target_module.get_function(wrapfn.name)

    def make_constant_array(self, builder, typ, ary):
        """
        Return dummy value.

        The value comes from a freshly made array structure for *typ*;
        *ary* is not used.

        XXX: We should be able to move cuda.const.array_like into here.
        """
        array_cls = self.make_array(typ)
        placeholder = array_cls(self, builder)
        return placeholder._getvalue()

    def insert_const_string(self, mod, string):
        """
        Unlike the parent version, this returns a pointer in the constant
        addrspace.

        A global named after the mangled *string* is reused when *mod*
        already has one; otherwise it is created as an internal global
        constant initialised with the NUL-terminated text.
        """
        const_space = nvvm.ADDRSPACE_CONSTANT
        text = Constant.stringz(string)
        gv_name = "__conststring__." + self.mangle_name(string)

        # Try to reuse existing global
        strvar = mod.globals.get(gv_name)
        if strvar is None:
            # Not defined yet
            strvar = mod.add_global_variable(text.type, name=gv_name,
                                             addrspace=const_space)
            strvar.linkage = LINKAGE_INTERNAL
            strvar.global_constant = True
            strvar.initializer = text

        # Cast to a i8* pointer
        char_type = strvar.type.pointee.element
        return Constant.bitcast(strvar, char_type.as_pointer(const_space))

    def insert_string_const_addrspace(self, builder, string):
        """
        Insert a constant string in the constant addresspace and return a
        generic i8 pointer to the data.

        This function attempts to deduplicate.
        """
        const_ptr = self.insert_const_string(builder.module, string)
        return self.insert_addrspace_conv(builder, const_ptr,
                                          nvvm.ADDRSPACE_CONSTANT)

    def insert_addrspace_conv(self, builder, ptr, addrspace):
        """
        Perform addrspace conversion according to the NVVM spec

        The conversion routine is obtained from
        ``nvvmutils.insert_addrspace_conv`` for the pointee type of *ptr*
        and *addrspace*; the result of calling it on *ptr* is returned.
        """
        pointee_type = ptr.type.pointee
        converter = nvvmutils.insert_addrspace_conv(builder.module,
                                                    pointee_type, addrspace)
        return builder.call(converter, [ptr])

    def optimize_function(self, func):
        """Run O1 function passes

        XXX skipped for now: this is currently a no-op and *func* is left
        untouched (the function pass manager run is disabled).
        """


class CUDACallConv(MinimalCallConv):
    """Calling convention of the CUDA target.

    Adds nothing of its own to ``MinimalCallConv``.
    """

