/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.gpu;

import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.Locale;

/**
 * Generates the PTX source of a family of GPU sorting kernels and writes it, as US-ASCII
 * text, to a caller-supplied stream.
 *
 * <p>The emitted module contains:
 * <ul>
 *   <li>{@code other1} .. {@code other4}: each thread loads {@code 2^k} elements into
 *       registers, applies {@code k} compare-and-swap levels (the first at distance
 *       {@code _stride}, then halving), and writes back only the elements that moved;</li>
 *   <li>{@code first4}: like {@code other4}, except that its first level compares each
 *       element with its mirror inside a block of {@code 2 * _stride} elements;</li>
 *   <li>{@code phase9}: each block sorts one 512-element tile in shared memory.</li>
 * </ul>
 * The compare-exchange pattern (a mirrored comparison followed by comparisons at halving
 * distances) is that of a bitonic sorting network; pairs are put in ascending order.
 */
final class PtxKernelGenerator {

    /** Number of phases in the shared-memory kernel; also log2 of its tile size. */
    private static final int PHASE_COUNT = 9;

    /** Elements staged in shared memory by each block of the {@code phase} kernel (512). */
    private static final int TILE_SIZE = 1 << PHASE_COUNT;

    /** Compare-and-swap pairs per tile step (256); one loop iteration handles one pair. */
    private static final int TILE_PAIRS = TILE_SIZE >> 1;

    /** Largest register-level step count emitted for the {@code other}/{@code first} kernels. */
    private static final int MAX_REGISTER_STEPS = 4;

    /**
     * Immediate loaded into slots past the end of the data. It is never ordered ahead of a
     * real element: integer MAX_VALUE, or a NaN for floating point (see {@link #compare}).
     */
    private final String paddingValue;

    /** log2 of the number of elements the kernel currently being emitted keeps in registers. */
    private int stepCount;

    /** True for the integer element types ('I', 'J'); false for 'F' and 'D'. */
    private final boolean integral;

    /** PTX type suffix of the element type, e.g. {@code ".f32"}. */
    private final String typeName;

    /** Element size in bytes. */
    private final int typeSize;

    /** Destination of the generated text; flushed but never closed, so the caller keeps the stream. */
    private final OutputStreamWriter writer;

    /**
     * Writes the complete PTX module for one element type.
     *
     * @param computeMajor selects the PTX target: values below 3 emit {@code .target sm_20},
     *     anything else {@code .target sm_30} (presumably the device's compute-capability
     *     major version; confirm with callers)
     * @param type JVM type descriptor of the element: 'D', 'F', 'I' or 'J'
     * @param outputStream receives the module as US-ASCII text; flushed, not closed
     * @throws IllegalArgumentException if {@code type} is not a supported descriptor
     * @throws IOException if writing to {@code outputStream} fails
     */
    public static void writeTo(int computeMajor, char type, OutputStream outputStream) throws IOException {
        new PtxKernelGenerator(outputStream, type).generate(computeMajor);
    }

    private PtxKernelGenerator(OutputStream outputStream, char type) {
        switch (type) {
            case 'D':
                this.integral = false;
                this.typeName = ".f64";
                this.typeSize = 8;
                // A NaN bit pattern (sign bit set); compare() orders every NaN after all numbers.
                this.paddingValue = "0dFFF8000000000000";
                break;
            case 'F':
                this.integral = false;
                this.typeName = ".f32";
                this.typeSize = 4;
                // A NaN bit pattern; compare() orders every NaN after all numbers.
                this.paddingValue = "0f7FFFFFFF";
                break;
            case 'I':
                this.integral = true;
                this.typeName = ".s32";
                this.typeSize = 4;
                this.paddingValue = "0x" + Integer.toHexString(Integer.MAX_VALUE);
                break;
            case 'J':
                this.integral = true;
                this.typeName = ".s64";
                this.typeSize = 8;
                this.paddingValue = "0x" + Long.toHexString(Long.MAX_VALUE);
                break;
            default:
                throw new IllegalArgumentException(String.valueOf(type));
        }
        this.stepCount = 0;
        this.writer = new OutputStreamWriter(outputStream, StandardCharsets.US_ASCII);
    }

    /** Writes {@code line} followed by a newline. */
    private void append(String line) throws IOException {
        writer.append(line).append('\n');
    }

    /**
     * Emits code that sets predicate {@code p0} when {@code vl<first>} must be swapped with
     * {@code vl<second>} to put the pair in ascending order.
     *
     * <p>Integers use a signed greater-than. For floating point, any NaN counts as greater
     * than every number (NaNs are never swapped with one another), and numerically equal
     * values are ordered by their signed bit patterns, so -0.0 precedes +0.0. The float
     * path clobbers {@code p1}, {@code vs0} and {@code vs1}.
     */
    private void compare(int first, int second) throws IOException {
        if (integral) {
            format("setp.gt%s p0,vl%d,vl%d;", typeName, first, second);
            return;
        }
        int bits = typeSize * 8;
        format("testp.number%s p1,vl%d;", typeName, second);
        format("setp.gtu.and%s p0,vl%d,vl%d,p1;", typeName, first, second);
        // Tie-break equal values (only +/-0.0 can differ) on their raw bits.
        format("@!p0 mov.b%d vs0,vl%d;", bits, first);
        format("@!p0 mov.b%d vs1,vl%d;", bits, second);
        format("@!p0 setp.eq%s p1,vl%d,vl%d;", typeName, first, second);
        format("@!p0 setp.gt.and.s%d p0,vs0,vs1,p1;", bits);
    }

    /**
     * Emits a conditional swap of {@code vl<first>} and {@code vl<second>}. When a swap
     * happens, both slots are recorded in the {@code moved} bitmask so that
     * {@link #scatterData} stores only slots that changed.
     */
    private void compareAndSwap(int first, int second) throws IOException {
        compare(first, second);
        format("@p0 mov%s tmp,vl%d;", typeName, first);
        format("@p0 mov%s vl%d,vl%d;", typeName, first, second);
        format("@p0 mov%s vl%d,tmp;", typeName, second);
        format("@p0 or.b32 moved,moved,%s;", constant((1 << first) | (1 << second)));
    }

    /**
     * Emits code computing {@code ix0 .. ix<count-1>}, the global element indices handled by
     * this thread, where {@code count = 2^stepCount}. Consecutive slots are {@code step}
     * apart ({@code step = stride >> (stepCount - 1)}). When {@code first} is set, the upper
     * half of the slots starts at the mirror of {@code ix<count/2 - 1>} inside a block of
     * {@code 2 * stride} elements.
     */
    private void computeIndices(boolean first) throws IOException {
        assert 1 <= stepCount && stepCount <= 5;
        int count = 1 << stepCount;
        // With a single step the slot distance is the stride itself.
        String distance = "stride";
        if (stepCount != 1) {
            distance = "step";
            format("shr.u32 step,stride,%d;", stepCount - 1);
        }
        format("sub.s32 mask,%s,1;", distance);
        // Linear thread id across a 2-D grid of 1-D blocks.
        append("mov.u32 rt0,%nctaid.x;");
        append("mov.u32 rt1,%ctaid.y;");
        append("mov.u32 rt2,%ctaid.x;");
        append("mad.lo.u32 threadId,rt0,rt1,rt2;");
        append("mov.u32 rt0,%ntid.x;");
        append("mov.u32 rt1,%tid.x;");
        append("mad.lo.u32 threadId,threadId,rt0,rt1;");
        append("not.b32 rt0,mask;");
        append("and.b32 rt0,rt0,threadId;");
        append("and.b32 rt1,threadId,mask;");
        format("mad.lo.u32 ix0,rt0,%d,rt1;", count);
        if (!first) {
            for (int slot = 1; slot < count; ++slot) {
                format("add.u32 ix%d,ix%d,%s;", slot, slot - 1, distance);
            }
            return;
        }
        int half = count >> 1;
        for (int slot = 1; slot < half; ++slot) {
            format("add.u32 ix%d,ix%d,%s;", slot, slot - 1, distance);
        }
        // Mirror within the 2*stride block: index XOR (2*stride - 1).
        append("mad.lo.u32 rt0,stride,2,-1;");
        format("xor.b32 ix%d,ix%d,rt0;", half, half - 1);
        for (int slot = half + 1; slot < count; ++slot) {
            format("add.u32 ix%d,ix%d,%s;", slot, slot - 1, distance);
        }
    }

    /**
     * Formats an integer immediate. The spelling depends on the current {@code stepCount}
     * (decimal for 1, then hex padded to 1, 2 or 4 digits). The shared-memory kernel is
     * emitted while {@code stepCount} is still {@link #MAX_REGISTER_STEPS}, so its constants
     * use the 4-digit form.
     */
    private String constant(int value) {
        String pattern;
        switch (stepCount) {
            case 1:
                pattern = "%d";
                break;
            case 2:
                pattern = "0x%x";
                break;
            case 3:
                pattern = "0x%02x";
                break;
            default:
                pattern = "0x%04x";
                break;
        }
        return String.format(Locale.ROOT, pattern, value);
    }

    /** Emits the register declarations used by an {@code other}/{@code first} kernel. */
    private void declareLocals() throws IOException {
        int count = 1 << stepCount;
        append(".reg .u64 data;");
        append(".reg .u32 length;");
        append(".reg .u32 stride;");
        append(".reg .u32 threadId;");
        append(".reg .u32 mask;");
        if (stepCount != 1) {
            append(".reg .u32 step;");
        }
        format(".reg %s tmp;", typeName);
        append(".reg .b32 moved;");
        append(".reg .b32 bit;");
        append(".reg .u32 rt<3>;");
        format(".reg .u32 ix<%d>;", count);
        format(".reg %s vl<%d>;", typeName, count);
        append(".reg .pred p<2>;");
        format(".reg .s%d vs<2>;", typeSize * 8);
        append(".reg .u64 ptr;");
    }

    /**
     * Emits the {@code phase9} kernel. Each block loads the 512-element tile that starts at
     * {@code linearBlockIndex * 512} into shared memory (padding slots at or past
     * {@code _length}), runs 9 phases of compare-and-swap steps over the tile with a
     * {@code bar.sync} before each step, and stores the in-range elements back.
     */
    private void emitFirstPhases() throws IOException {
        append(".visible .entry");
        format("phase%d(.param .u64 _data,.param .u32 _length)", PHASE_COUNT);
        append(".maxntid 256,1,1");
        append("{");
        format(".shared .align %d %s _sharedData[%d];", typeSize, typeName, TILE_SIZE);
        append(".reg .u64 data;");
        append(".reg .u32 length;");
        append(".reg .u64 sharedData;");
        append(".reg .u64 dataPtr;");
        append(".reg .u64 sharedPtr<2>;");
        append(".reg .u32 baseIndex;");
        append(".reg .u32 blockDimX;");
        append(".reg .u32 globalIndex;");
        append(".reg .u32 workId;");
        append(".reg .pred p<2>;");
        format(".reg .s%d vs<2>;", typeSize * 8);
        append(".reg .u32 ix<2>;");
        append(".reg .u32 rt<3>;");
        format(".reg %s vl<2>;", typeName);

        // Load: threads stride over the tile by the block width.
        append("ld.param.u64 data,[_data];");
        append("cvta.to.global.u64 data,data;");
        append("ld.param.u32 length,[_length];");
        append("mov.u64 sharedData,_sharedData;");
        append("mov.u32 blockDimX,%ntid.x;");
        append("mov.u32 rt0,%nctaid.x;");
        append("mov.u32 rt1,%ctaid.y;");
        append("mov.u32 rt2,%ctaid.x;");
        append("mad.lo.u32 baseIndex,rt0,rt1,rt2;");
        format("shl.b32 baseIndex,baseIndex,%d;", PHASE_COUNT);
        append("mov.u32 workId,%tid.x;");
        append("bra loadTest;");
        append("loadLoop:");
        append("add.u32 globalIndex,baseIndex,workId;");
        format("mov%s vl0,%s;", typeName, paddingValue);
        append("setp.lt.u32 p0,globalIndex,length;");
        format("@p0 mad.wide.u32 dataPtr,globalIndex,%d,data;", typeSize);
        format("@p0 ld.global%s vl0,[dataPtr];", typeName);
        format("mad.wide.u32 sharedPtr0,workId,%d,sharedData;", typeSize);
        format("st.shared%s [sharedPtr0],vl0;", typeName);
        append("add.u32 workId,workId,blockDimX;");
        append("loadTest:");
        format("setp.lt.u32 p0,workId,%d;", TILE_SIZE);
        append("@p0 bra loadLoop;");

        for (int phase = 0; phase < PHASE_COUNT; ++phase) {
            for (int step = 0; step <= phase; ++step) {
                emitTileStep(phase, step);
            }
        }

        // Store: write back only slots that map to indices below length.
        append("bar.sync 0;");
        append("mov.u32 workId,%tid.x;");
        append("bra storeTest;");
        append("storeLoop:");
        append("{");
        append("add.u32 globalIndex,baseIndex,workId;");
        append("setp.lt.u32 p0,globalIndex,length;");
        format("@p0 mad.wide.u32 sharedPtr0,workId,%d,sharedData;", typeSize);
        format("@p0 ld.shared%s vl0,[sharedPtr0];", typeName);
        format("@p0 mad.wide.u32 dataPtr,globalIndex,%d,data;", typeSize);
        format("@p0 st.global%s [dataPtr],vl0;", typeName);
        append("add.u32 workId,workId,blockDimX;");
        append("}");
        append("storeTest:");
        format("setp.lt.u32 p0,workId,%d;", TILE_SIZE);
        append("@p0 bra storeLoop;");
        append("}");
    }

    /**
     * Emits one compare-and-swap step over the shared-memory tile. Each loop iteration
     * handles one of the 256 pairs, and threads stride by the block width.
     *
     * @param phase zero-based phase; phase {@code p} works on blocks of {@code 2 << p} elements
     * @param step zero-based step within the phase. Step 0 of every phase but the first
     *     pairs each element with its mirror in the block (via XOR). Every other step pairs
     *     elements {@code 1 << (phase - step)} apart.
     */
    private void emitTileStep(int phase, int step) throws IOException {
        String loopLabel = String.format(Locale.ROOT, "workLoop_%d_%d", phase + 1, step + 1);
        String testLabel = String.format(Locale.ROOT, "workTest_%d_%d", phase + 1, step + 1);
        int distance = 1 << (phase - step);
        append("bar.sync 0;");
        append("mov.u32 workId,%tid.x;");
        format("bra %s;", testLabel);
        format("%s:", loopLabel);
        // ix0 = 2*workId - (workId mod distance): the lower element of this work item's pair.
        append("shl.b32 ix0,workId,1;");
        if (step != phase) {
            format("and.b32 rt0,workId,%s;", constant(distance - 1));
            append("sub.u32 ix0,ix0,rt0;");
        }
        if (step == 0 && step != phase) {
            format("xor.b32 ix1,ix0,%s;", constant((2 << phase) - 1));
        } else {
            format("add.u32 ix1,ix0,%s;", constant(distance));
        }
        format("mad.wide.u32 sharedPtr0,ix0,%d,sharedData;", typeSize);
        format("ld.shared%s vl0,[sharedPtr0];", typeName);
        format("mad.wide.u32 sharedPtr1,ix1,%d,sharedData;", typeSize);
        format("ld.shared%s vl1,[sharedPtr1];", typeName);
        compare(0, 1);
        format("@p0 st.shared%s [sharedPtr0],vl1;", typeName);
        format("@p0 st.shared%s [sharedPtr1],vl0;", typeName);
        append("add.u32 workId,workId,blockDimX;");
        format("%s:", testLabel);
        format("setp.lt.u32 p0,workId,%d;", TILE_PAIRS);
        format("@p0 bra %s;", loopLabel);
    }

    /**
     * Emits an {@code other<stepCount>} kernel or, if {@code first} is set, a
     * {@code first<stepCount>} kernel.
     */
    private void emitKernel(boolean first) throws IOException {
        append(".visible .entry");
        format("%s%d(.param .u64 _data,.param .u32 _length,.param .u32 _stride)",
                first ? "first" : "other", stepCount);
        append(".maxntid 256,1,1");
        append("{");
        declareLocals();
        append("ld.param.u64 data,[_data];");
        append("cvta.to.global.u64 data,data;");
        append("ld.param.u32 length,[_length];");
        append("ld.param.u32 stride,[_stride];");
        computeIndices(first);
        gatherData();
        sortLocally(first);
        scatterData();
        append("}");
    }

    /** Emits the module header; {@code computeMajor < 3} targets sm_20, otherwise sm_30. */
    private void emitPreamble(int computeMajor) throws IOException {
        append(".version 3.2");
        format(".target sm_%d", computeMajor < 3 ? 20 : 30);
        append(".address_size 64");
    }

    /**
     * Formats and writes one line.
     *
     * <p>Uses {@link Locale#ROOT}. Under the default locale, {@code %d} may produce
     * locale-specific digits (e.g. Arabic-Indic), which are not valid PTX and which the
     * US-ASCII writer would turn into '?'.
     */
    private void format(String pattern, Object... args) throws IOException {
        append(String.format(Locale.ROOT, pattern, args));
    }

    /** Emits loads of every slot; slots whose index is at or past {@code length} get the padding value. */
    private void gatherData() throws IOException {
        int count = 1 << stepCount;
        for (int slot = 0; slot < count; ++slot) {
            format("mov%s vl%d,%s;", typeName, slot, paddingValue);
            format("setp.lt.u32 p0,ix%d,length;", slot);
            format("@p0 mad.wide.u32 ptr,ix%d,%d,data;", slot, typeSize);
            format("@p0 ld.global%s vl%d,[ptr];", typeName, slot);
        }
    }

    /** Emits the whole module (preamble, other1..other4, first4, phase9) and flushes the writer. */
    private void generate(int computeMajor) throws IOException {
        emitPreamble(computeMajor);
        for (int steps = 1; steps <= MAX_REGISTER_STEPS; ++steps) {
            stepCount = steps;
            emitKernel(false);
        }
        // stepCount stays at MAX_REGISTER_STEPS: first4 and the phase kernel's constants rely on it.
        emitKernel(true);
        emitFirstPhases();
        writer.flush();
    }

    /** Emits stores for the slots whose bit is set in the {@code moved} mask. */
    private void scatterData() throws IOException {
        int count = 1 << stepCount;
        for (int slot = 0; slot < count; ++slot) {
            format("and.b32 bit,moved,%s;", constant(1 << slot));
            append("setp.ne.b32 p0,bit,0;");
            format("@p0 mad.wide.u32 ptr,ix%d,%d,data;", slot, typeSize);
            format("@p0 st.global%s [ptr],vl%d;", typeName, slot);
        }
    }

    /**
     * Emits the in-register compare-and-swap network over {@code 2^stepCount} slots. When
     * {@code first} is set, the first level pairs slot {@code i} with slot {@code count-1-i}.
     * Each remaining level pairs slots {@code distance} apart, with {@code distance} halving
     * from {@code count/2} (or {@code count/4} after a mirror level) down to 1.
     */
    private void sortLocally(boolean first) throws IOException {
        int count = 1 << stepCount;
        int level = 0;
        append("mov.b32 moved,0;");
        if (first) {
            int last = count - 1;
            for (int slot = 0; slot < (count >> 1); ++slot) {
                compareAndSwap(slot, slot ^ last);
            }
            ++level;
        }
        for (; level < stepCount; ++level) {
            int distance = count >> (level + 1);
            int span = distance << 1;
            for (int base = 0; base < count; base += span) {
                for (int offset = 0; offset < distance; ++offset) {
                    compareAndSwap(base + offset, base + offset + distance);
                }
            }
        }
    }
}

