/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.gpu;

import com.ibm.cuda.CudaBuffer;
import com.ibm.cuda.CudaDevice;
import com.ibm.cuda.CudaException;
import com.ibm.cuda.CudaGrid;
import com.ibm.cuda.CudaKernel;
import com.ibm.cuda.CudaModule;
import com.ibm.cuda.CudaStream;
import com.ibm.cuda.Dim3;
import com.ibm.gpu.CUDAManager;
import com.ibm.gpu.GPUConfigurationException;
import com.ibm.gpu.GPUSortException;
import com.ibm.gpu.PtxKernelGenerator;
import com.ibm.oti.vm.VM;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;

final class SortNetwork {
    /** Boxed powers of two: element {@code i} holds {@code 1 << i}; passed as kernel launch arguments. */
    private static final Integer[] powersOf2;
    /** Cache of load outcomes (success or failure), keyed by device id and element type. */
    private static final ConcurrentHashMap<LoadKey, LoadResult> resultsMap;
    /** Device on which buffers and streams for this network are created. */
    private final CudaDevice device;
    /** Upper bound applied to the x-dimension when building launch grids (read from device attribute 5). */
    private final int maxGridDimX;
    /** Kernel resolved from the module under the name "first4". */
    private final CudaKernel sortFirst4;
    /** Kernel resolved from the module under the name "other1". */
    private final CudaKernel sortOther1;
    /** Kernel resolved from the module under the name "other2". */
    private final CudaKernel sortOther2;
    /** Kernel resolved from the module under the name "other3". */
    private final CudaKernel sortOther3;
    /** Kernel resolved from the module under the name "other4". */
    private final CudaKernel sortOther4;
    /** Kernel resolved from the module under the name "phase9"; always launched first by sortBuffer. */
    private final CudaKernel sortPhase9;

    /**
     * Validates that {@code n2..n3} is a well-formed range inside an array of length {@code n}.
     *
     * @param n  length of the array being indexed
     * @param n2 start index of the range
     * @param n3 end index of the range
     * @throws IllegalArgumentException       if {@code n2 > n3}
     * @throws ArrayIndexOutOfBoundsException if {@code n2 < 0} or {@code n3 > n}
     */
    private static void checkIndices(int n, int n2, int n3) {
        // The reversed-range test deliberately comes first so it wins over the bounds tests.
        if (n2 > n3) {
            throw new IllegalArgumentException();
        } else if (n2 < 0) {
            throw new ArrayIndexOutOfBoundsException(n2);
        } else if (n3 > n) {
            throw new ArrayIndexOutOfBoundsException(n3);
        }
    }

    /**
     * Returns the sort network for the given device and element type, creating and caching
     * the load outcome on first use.
     *
     * @param n device id
     * @param c element type code (the callers in this class pass 'D', 'F', 'I' or 'J')
     * @return the loaded network
     * @throws GPUConfigurationException if the cached outcome records a failure
     */
    private static SortNetwork load(int n, char c) throws GPUConfigurationException {
        LoadResult outcome = resultsMap.computeIfAbsent(new LoadKey(n, c), key -> LoadResult.create(key));
        return outcome.get();
    }

    /**
     * Rounds {@code n} up to the nearest multiple of {@code n2}.
     *
     * <p>The result is computed in {@code int}: when the next multiple exceeds
     * {@code Integer.MAX_VALUE} the addition wraps and the result is negative.
     *
     * @param n  value to round; asserted positive
     * @param n2 unit to round to; asserted positive
     * @return {@code n} if it is already a multiple of {@code n2}, otherwise the next multiple
     */
    private static int roundUp(int n, int n2) {
        assert (n > 0);
        assert (n2 > 0);
        final int remainder = n % n2;
        if (remainder == 0) {
            return n;
        }
        return n + (n2 - remainder);
    }

    /**
     * Returns the number of bits needed to represent {@code n}, treating any value below 1 as 1.
     *
     * @param n value to measure
     * @return position of the highest set bit, counted from 1; never less than 1
     */
    private static int significantBits(int n) {
        final int clamped = n < 1 ? 1 : n;
        return Integer.SIZE - Integer.numberOfLeadingZeros(clamped);
    }

    /**
     * Sorts a range of a {@code double} array on the given device, tracing start, failure and success.
     *
     * @param n      device id
     * @param dArray array whose range is sorted in place
     * @param n2     start index of the range
     * @param n3     end index of the range
     * @throws GPUConfigurationException if the sort network cannot be loaded for the device
     * @throws GPUSortException          if the sort itself fails
     */
    static void sortArray(int n, double[] dArray, int n2, int n3) throws GPUConfigurationException, GPUSortException {
        final CUDAManager manager = SortNetwork.traceStart(n, "double", n2, n3);
        try {
            SortNetwork.load(n, 'D').sort(dArray, n2, n3);
        } catch (GPUConfigurationException | GPUSortException failure) {
            // Report the failure before propagating it unchanged.
            SortNetwork.traceFailure(manager, failure);
            throw failure;
        }
        SortNetwork.traceSuccess(manager, n, "double");
    }

    /**
     * Sorts a range of a {@code float} array on the given device, tracing start, failure and success.
     *
     * @param n      device id
     * @param fArray array whose range is sorted in place
     * @param n2     start index of the range
     * @param n3     end index of the range
     * @throws GPUConfigurationException if the sort network cannot be loaded for the device
     * @throws GPUSortException          if the sort itself fails
     */
    static void sortArray(int n, float[] fArray, int n2, int n3) throws GPUConfigurationException, GPUSortException {
        final CUDAManager manager = SortNetwork.traceStart(n, "float", n2, n3);
        try {
            SortNetwork.load(n, 'F').sort(fArray, n2, n3);
        } catch (GPUConfigurationException | GPUSortException failure) {
            // Report the failure before propagating it unchanged.
            SortNetwork.traceFailure(manager, failure);
            throw failure;
        }
        SortNetwork.traceSuccess(manager, n, "float");
    }

    /**
     * Sorts a range of an {@code int} array on the given device, tracing start, failure and success.
     *
     * @param n      device id
     * @param nArray array whose range is sorted in place
     * @param n2     start index of the range
     * @param n3     end index of the range
     * @throws GPUConfigurationException if the sort network cannot be loaded for the device
     * @throws GPUSortException          if the sort itself fails
     */
    static void sortArray(int n, int[] nArray, int n2, int n3) throws GPUConfigurationException, GPUSortException {
        final CUDAManager manager = SortNetwork.traceStart(n, "int", n2, n3);
        try {
            SortNetwork.load(n, 'I').sort(nArray, n2, n3);
        } catch (GPUConfigurationException | GPUSortException failure) {
            // Report the failure before propagating it unchanged.
            SortNetwork.traceFailure(manager, failure);
            throw failure;
        }
        SortNetwork.traceSuccess(manager, n, "int");
    }

    /**
     * Sorts a range of a {@code long} array on the given device, tracing start, failure and success.
     *
     * @param n      device id
     * @param lArray array whose range is sorted in place
     * @param n2     start index of the range
     * @param n3     end index of the range
     * @throws GPUConfigurationException if the sort network cannot be loaded for the device
     * @throws GPUSortException          if the sort itself fails
     */
    static void sortArray(int n, long[] lArray, int n2, int n3) throws GPUConfigurationException, GPUSortException {
        final CUDAManager manager = SortNetwork.traceStart(n, "long", n2, n3);
        try {
            SortNetwork.load(n, 'J').sort(lArray, n2, n3);
        } catch (GPUConfigurationException | GPUSortException failure) {
            // Report the failure before propagating it unchanged.
            SortNetwork.traceFailure(manager, failure);
            throw failure;
        }
        SortNetwork.traceSuccess(manager, n, "long");
    }

    /**
     * Forwards the localized message of a failed sort to the manager's verbose output.
     *
     * @param cUDAManager manager that receives the message
     * @param exception   failure whose localized message is reported
     */
    private static void traceFailure(CUDAManager cUDAManager, Exception exception) {
        final String message = exception.getLocalizedMessage();
        cUDAManager.outputIfVerbose(message);
    }

    /**
     * Obtains the manager instance and, when verbose output is enabled, announces the sort
     * that is about to run.
     *
     * @param n      device id
     * @param string element type name used in the message
     * @param n2     start index of the range
     * @param n3     end index of the range
     * @return the manager instance, for use by the later trace calls
     */
    private static CUDAManager traceStart(int n, String string, int n2, int n3) {
        final CUDAManager manager = CUDAManager.instanceInternal();
        if (!manager.getVerboseGPUOutput()) {
            // Skip building the message when nobody will see it.
            return manager;
        }
        StringBuilder text = new StringBuilder();
        text.append("Using device: ").append(n);
        text.append(" to sort ").append(string);
        text.append(" array; elements ").append(n2);
        text.append(" to ").append(n3);
        manager.outputIfVerbose(text.toString());
        return manager;
    }

    /**
     * Reports a completed sort when verbose output is enabled.
     *
     * @param cUDAManager manager that receives the message
     * @param n           device id
     * @param string      element type name used in the message
     */
    private static void traceSuccess(CUDAManager cUDAManager, int n, String string) {
        if (!cUDAManager.getVerboseGPUOutput()) {
            return;
        }
        StringBuilder text = new StringBuilder();
        text.append("Sorted ").append(string);
        text.append("s on device ").append(n);
        text.append(" successfully");
        cUDAManager.outputIfVerbose(text.toString());
    }

    /**
     * Binds a sort network to a device and resolves the six kernels it launches from the
     * given module.
     *
     * @param cudaDevice device the network runs on
     * @param cudaModule loaded module that contains the kernels
     * @throws CudaException if the device attribute cannot be read or a kernel cannot be resolved
     */
    SortNetwork(CudaDevice cudaDevice, CudaModule cudaModule) throws CudaException {
        this.device = cudaDevice;
        // Named constant instead of the raw attribute id 5.
        this.maxGridDimX = cudaDevice.getAttribute(CudaDevice.ATTRIBUTE_MAX_GRID_DIM_X);
        this.sortFirst4 = new CudaKernel(cudaModule, "first4");
        this.sortOther1 = new CudaKernel(cudaModule, "other1");
        this.sortOther2 = new CudaKernel(cudaModule, "other2");
        this.sortOther3 = new CudaKernel(cudaModule, "other3");
        this.sortOther4 = new CudaKernel(cudaModule, "other4");
        this.sortPhase9 = new CudaKernel(cudaModule, "phase9");
    }

    /**
     * Builds a launch grid with enough blocks of {@code n2} threads to cover {@code n} work items.
     *
     * @param n          number of work items
     * @param n2         threads per block
     * @param cudaStream stream the grid is associated with
     * @return a grid of at least one block
     */
    private CudaGrid makeGrid(int n, int n2, CudaStream cudaStream) {
        // Ceiling division, clamped so that at least one block is launched.
        int blockCount = (n + n2 - 1) / n2;
        if (blockCount < 1) {
            blockCount = 1;
        }
        Dim3 gridDim = this.makeGridDim(blockCount);
        Dim3 blockDim = new Dim3(n2);
        return new CudaGrid(gridDim, blockDim, cudaStream);
    }

    /**
     * Splits a block count into two grid dimensions so that the first does not exceed
     * {@code maxGridDimX}.
     *
     * <p>While the first dimension is too large it is halved (rounding up) and the second
     * dimension is doubled, so their product is never smaller than the requested count.
     *
     * @param n requested number of blocks; values below 1 are treated as 1
     * @return the two-dimensional grid size
     */
    private Dim3 makeGridDim(int n) {
        int x = n < 1 ? 1 : n;
        int y = 1;
        while (x > this.maxGridDimX) {
            // Halve with rounding up: same as bumping an odd value to even and then shifting.
            x = (x + 1) >> 1;
            y <<= 1;
        }
        return new Dim3(x, y);
    }

    /**
     * Sorts the range {@code n..n2} of a {@code double} array by copying it to a device
     * buffer, sorting the buffer and copying it back.
     *
     * <p>Ranges of fewer than two elements are only validated, since there is nothing to sort.
     *
     * @param dArray array to sort in place
     * @param n      start index of the range
     * @param n2     end index of the range
     * @throws GPUSortException if any device operation fails; the cause is preserved
     */
    private void sort(double[] dArray, int n, int n2) throws GPUSortException {
        final int count = n2 - n;
        if (count >= 2) {
            final long byteCount = (long) count * Double.BYTES;
            try (CudaBuffer deviceData = new CudaBuffer(this.device, byteCount)) {
                deviceData.copyFrom(dArray, n, n2);
                this.sortBuffer(deviceData, count);
                deviceData.copyTo(dArray, n, n2);
            } catch (CudaException failure) {
                throw new GPUSortException(failure.getLocalizedMessage(), failure);
            }
        } else {
            SortNetwork.checkIndices(dArray.length, n, n2);
        }
    }

    /**
     * Sorts the range {@code n..n2} of a {@code float} array by copying it to a device
     * buffer, sorting the buffer and copying it back.
     *
     * <p>Ranges of fewer than two elements are only validated, since there is nothing to sort.
     *
     * @param fArray array to sort in place
     * @param n      start index of the range
     * @param n2     end index of the range
     * @throws GPUSortException if any device operation fails; the cause is preserved
     */
    private void sort(float[] fArray, int n, int n2) throws GPUSortException {
        final int count = n2 - n;
        if (count >= 2) {
            final long byteCount = (long) count * Float.BYTES;
            try (CudaBuffer deviceData = new CudaBuffer(this.device, byteCount)) {
                deviceData.copyFrom(fArray, n, n2);
                this.sortBuffer(deviceData, count);
                deviceData.copyTo(fArray, n, n2);
            } catch (CudaException failure) {
                throw new GPUSortException(failure.getLocalizedMessage(), failure);
            }
        } else {
            SortNetwork.checkIndices(fArray.length, n, n2);
        }
    }

    /**
     * Sorts the range {@code n..n2} of an {@code int} array by copying it to a device
     * buffer, sorting the buffer and copying it back.
     *
     * <p>Ranges of fewer than two elements are only validated, since there is nothing to sort.
     *
     * @param nArray array to sort in place
     * @param n      start index of the range
     * @param n2     end index of the range
     * @throws GPUSortException if any device operation fails; the cause is preserved
     */
    private void sort(int[] nArray, int n, int n2) throws GPUSortException {
        final int count = n2 - n;
        if (count >= 2) {
            final long byteCount = (long) count * Integer.BYTES;
            try (CudaBuffer deviceData = new CudaBuffer(this.device, byteCount)) {
                deviceData.copyFrom(nArray, n, n2);
                this.sortBuffer(deviceData, count);
                deviceData.copyTo(nArray, n, n2);
            } catch (CudaException failure) {
                throw new GPUSortException(failure.getLocalizedMessage(), failure);
            }
        } else {
            SortNetwork.checkIndices(nArray.length, n, n2);
        }
    }

    /**
     * Sorts the range {@code n..n2} of a {@code long} array by copying it to a device
     * buffer, sorting the buffer and copying it back.
     *
     * <p>Ranges of fewer than two elements are only validated, since there is nothing to sort.
     *
     * @param lArray array to sort in place
     * @param n      start index of the range
     * @param n2     end index of the range
     * @throws GPUSortException if any device operation fails; the cause is preserved
     */
    private void sort(long[] lArray, int n, int n2) throws GPUSortException {
        final int count = n2 - n;
        if (count >= 2) {
            final long byteCount = (long) count * Long.BYTES;
            try (CudaBuffer deviceData = new CudaBuffer(this.device, byteCount)) {
                deviceData.copyFrom(lArray, n, n2);
                this.sortBuffer(deviceData, count);
                deviceData.copyTo(lArray, n, n2);
            } catch (CudaException failure) {
                throw new GPUSortException(failure.getLocalizedMessage(), failure);
            }
        } else {
            SortNetwork.checkIndices(lArray.length, n, n2);
        }
    }

    /**
     * Sorts the first {@code n} elements of a device buffer by launching the network's kernels
     * on a dedicated stream.
     *
     * <p>The "phase9" kernel is always launched. If indexing {@code n} elements needs more than
     * nine bits, each further phase launches "first4" with stride {@code 2^phase}, then "other4"
     * with strides stepping down by four, and finally one of "other3"/"other2"/"other1" to finish
     * the remaining low strides.
     *
     * @param cudaBuffer device buffer holding the elements
     * @param n          number of elements; the callers in this class pass at least 2
     * @throws CudaException if creating the stream or launching a kernel fails
     */
    private void sortBuffer(CudaBuffer cudaBuffer, int n) throws CudaException {
        try (CudaStream stream = new CudaStream(this.device)) {
            // Boxed once so the same Integer is handed to every launch.
            final Integer length = n;
            // One work item per pair of elements; used by phase9 and all "other" kernels.
            final CudaGrid halfGrid = this.makeGrid(n >> 1, 256, stream);

            this.sortPhase9.launch(halfGrid, cudaBuffer, length);

            final int phaseCount = SortNetwork.significantBits(n - 1);
            if (phaseCount <= 9) {
                return;
            }

            for (int phase = 9; phase < phaseCount; ++phase) {
                // Round the length up to a multiple of 2^phase in long arithmetic: for
                // n > 2^30 and phase 30 the int result would wrap negative and the grid
                // would shrink to a single block.
                final long unit = 1L << phase;
                final long roundedLength = (n + unit - 1L) / unit * unit;
                // roundedLength <= 2^31, so half of it always fits in an int.
                final int firstItems = (int) (roundedLength >> 1);
                final CudaGrid firstGrid = this.makeGrid(firstItems, 256, stream);
                this.sortFirst4.launch(firstGrid, cudaBuffer, length, powersOf2[phase]);

                for (int step = phase - 4; step >= 3; step -= 4) {
                    this.sortOther4.launch(halfGrid, cudaBuffer, length, powersOf2[step]);
                }

                switch (phase & 3) {
                    case 2:
                        this.sortOther3.launch(halfGrid, cudaBuffer, length, powersOf2[2]);
                        break;
                    case 1:
                        this.sortOther2.launch(halfGrid, cudaBuffer, length, powersOf2[1]);
                        break;
                    case 0:
                        this.sortOther1.launch(halfGrid, cudaBuffer, length, powersOf2[0]);
                        break;
                    default:
                        // phase % 4 == 3: the other4 loop above ended on step 3, nothing is left.
                        break;
                }
            }
        }
    }

    // Fills the table of boxed powers of two (1 << 0 .. 1 << 30) and creates the empty,
    // properly typed cache of load outcomes.
    static {
        Integer[] table = new Integer[31];
        for (int bit = 0; bit < table.length; ++bit) {
            table[bit] = 1 << bit;
        }
        powersOf2 = table;
        // Diamond instead of the raw type, so the assignment is checked by the compiler.
        resultsMap = new ConcurrentHashMap<>();
    }

    /**
     * Unloads every registered module when the JVM shuts down.
     *
     * <p>Loading this class registers one instance as a shutdown hook thread.
     */
    private static final class ShutdownHook
    implements Runnable {
        /** Modules waiting to be unloaded at shutdown. */
        private static final Queue<CudaModule> modules = new ConcurrentLinkedQueue<CudaModule>();

        /**
         * Registers a module to be unloaded at JVM shutdown.
         *
         * @param cudaModule module to unload later
         */
        public static void unloadOnShutdown(CudaModule cudaModule) {
            modules.add(cudaModule);
        }

        private ShutdownHook() {
        }

        /** Drains the queue, unloading each module in turn. */
        @Override
        public void run() {
            CudaModule cudaModule;
            while ((cudaModule = modules.poll()) != null) {
                try {
                    cudaModule.unload();
                }
                catch (CudaException ignored) {
                    // Best effort: the JVM is shutting down, so a failed unload is not
                    // reported and must not stop the remaining modules from being unloaded.
                }
            }
        }

        static {
            // The cast selects the PrivilegedAction overload; a bare lambda is ambiguous
            // between PrivilegedAction and PrivilegedExceptionAction.
            AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
                Thread thread = VM.getVMLangAccess().createThread(new ShutdownHook(), "GPU sort shutdown helper", true, false, false, ClassLoader.getSystemClassLoader());
                Runtime.getRuntime().addShutdownHook(thread);
                return null;
            });
        }
    }

    /**
     * Outcome of trying to load a sort network for one device and element type: either the
     * network, or a description of why loading failed.
     */
    private static final class LoadResult {
        /** The loaded network, or {@code null} when loading failed. */
        private final SortNetwork network;
        /** Description of the failure, or {@code null} when loading succeeded. */
        private final String problem;

        /**
         * Generates the kernel source for the key's device and element type and loads it.
         * Never throws for device or I/O errors; those are captured in the returned result.
         *
         * @param loadKey device id and element type to load for
         * @return a success or failure result
         */
        static LoadResult create(LoadKey loadKey) {
            try {
                CudaDevice cudaDevice = new CudaDevice(loadKey.deviceId);
                // Named constant instead of the raw attribute id 75.
                int capabilityMajor = cudaDevice.getAttribute(CudaDevice.ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR);
                if (capabilityMajor < 2) {
                    return LoadResult.failure("Unsupported device");
                }
                ByteArrayOutputStream kernelSource = new ByteArrayOutputStream(49152);
                PtxKernelGenerator.writeTo(capabilityMajor, loadKey.type, kernelSource);
                // Trailing zero byte; presumably the terminator the module loader expects - confirm.
                kernelSource.write(0);
                byte[] image = kernelSource.toByteArray();
                PrivilegedAction<LoadResult> loader = () -> LoadResult.load(cudaDevice, image);
                return AccessController.doPrivileged(loader);
            }
            catch (CudaException | IOException exception) {
                return LoadResult.failure(exception);
            }
        }

        private static LoadResult failure(Exception exception) {
            return new LoadResult(exception);
        }

        private static LoadResult failure(String string) {
            return new LoadResult(string);
        }

        /**
         * Loads the module image and builds the network from it. On success the module is
         * handed to the shutdown hook; on any failure after loading it is unloaded here.
         */
        private static LoadResult load(CudaDevice cudaDevice, byte[] byArray) {
            try {
                CudaModule cudaModule = new CudaModule(cudaDevice, byArray);
                boolean handedOff = false;
                try {
                    LoadResult loadResult = LoadResult.success(new SortNetwork(cudaDevice, cudaModule));
                    ShutdownHook.unloadOnShutdown(cudaModule);
                    handedOff = true;
                    return loadResult;
                }
                finally {
                    if (!handedOff) {
                        cudaModule.unload();
                    }
                }
            }
            catch (CudaException cudaException) {
                return LoadResult.failure(cudaException);
            }
        }

        private static LoadResult success(SortNetwork sortNetwork) {
            return new LoadResult(sortNetwork);
        }

        /**
         * Describes an exception for use as a failure reason. Falls back to
         * {@code toString()} when there is no message, so a failure result always carries a
         * non-null problem and {@link #get()} cannot return {@code null} for it.
         */
        private static String describe(Exception exception) {
            String message = exception.getLocalizedMessage();
            return message != null ? message : exception.toString();
        }

        private LoadResult(Exception exception) {
            this(LoadResult.describe(exception));
        }

        private LoadResult(SortNetwork sortNetwork) {
            this.network = sortNetwork;
            this.problem = null;
        }

        private LoadResult(String string) {
            this.network = null;
            this.problem = string;
        }

        /**
         * Returns the loaded network.
         *
         * @return the network
         * @throws GPUConfigurationException carrying the recorded problem if loading failed
         */
        SortNetwork get() throws GPUConfigurationException {
            if (this.problem != null) {
                throw new GPUConfigurationException(this.problem);
            }
            return this.network;
        }
    }

    /** Immutable cache key pairing a device id with an element type code. */
    private static final class LoadKey {
        /** Id of the device the network is loaded on. */
        final int deviceId;
        /** Element type code. */
        final char type;

        LoadKey(int n, char c) {
            this.deviceId = n;
            this.type = c;
        }

        /** Two keys are equal when both the device id and the type code match. */
        @Override
        public boolean equals(Object object) {
            if (!(object instanceof LoadKey)) {
                return false;
            }
            LoadKey other = (LoadKey) object;
            return this.deviceId == other.deviceId && this.type == other.type;
        }

        /** Combines the device id, shifted left by four bits, with the type code. */
        @Override
        public int hashCode() {
            return (this.deviceId << 4) ^ this.type;
        }
    }
}

