/*
 * Decompiled with CFR 0.152.
 */
package io.privacyresearch.equation.proxy;

import io.privacyresearch.equation.proxy.CertificateUtils;
import io.privacyresearch.equation.proxy.HttpBridge;
import io.privacyresearch.equation.proxy.SignalHttpBridge;
import io.privacyresearch.equation.proxy.SignalWebSocketBridge;
import io.privacyresearch.equation.proxy.WebSocketBridge;
import io.privacyresearch.equation.proxy.WebSocketBridgeFactory;
import java.io.EOFException;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.math.BigInteger;
import java.net.InetAddress;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Consumer;
import java.util.logging.Level;
import net.luminis.quic.QuicConnection;
import net.luminis.quic.QuicConnectionImpl;
import net.luminis.quic.QuicStream;
import net.luminis.quic.Version;
import net.luminis.quic.frame.PingFrame;
import net.luminis.quic.frame.QuicFrame;
import net.luminis.quic.log.Logger;
import net.luminis.quic.log.SysOutLogger;
import net.luminis.quic.server.ApplicationProtocolConnection;
import net.luminis.quic.server.ApplicationProtocolConnectionFactory;
import net.luminis.quic.server.ServerConnectionImpl;
import net.luminis.quic.server.ServerConnector;

/**
 * QUIC server front-end of the proxy.
 *
 * <p>Registers two application protocols with a kwik {@link ServerConnector}:
 * <ul>
 *   <li>{@value #PROTOCOL}: every peer-initiated stream carries one request. The request is read
 *       to end-of-stream, handed to the {@link HttpBridge}, and the reply is written back on the
 *       same stream, which is then closed.</li>
 *   <li>{@value #WS_PROTOCOL}: a single long-lived stream per connection. It is bridged to a
 *       {@link WebSocketBridge} using 4-byte big-endian length-prefixed frames in both directions.
 *       A keepalive thread and an idle health check run alongside it.</li>
 * </ul>
 */
public class QuicServerTransport {
    static final String PROTOCOL = "swave";
    static final String WS_PROTOCOL = "pwave";
    public final int port;
    public final String certificateFile;
    public final String keyFile;
    /** Single thread on which all pwave replies are written, so frames never interleave. */
    static ExecutorService executorService = Executors.newSingleThreadExecutor();
    private static final String DEFAULT_CERT = "/tmp/certs/server-cert.pem";
    private static final String DEFAULT_KEY = "/tmp/certs/server-key.pem";
    public static int DEFAULT_PORT = 9786;
    private final HttpBridge httpBridge;
    private final WebSocketBridgeFactory webSocketBridgeFactory;
    private static final java.util.logging.Logger LOG = java.util.logging.Logger.getLogger(QuicServerTransport.class.getName());

    /** Creates a transport on {@link #DEFAULT_PORT} with the default certificate and key paths. */
    public QuicServerTransport() {
        this(DEFAULT_PORT, DEFAULT_CERT, DEFAULT_KEY);
    }

    /**
     * Creates a transport that uses the Signal HTTP and websocket bridges.
     *
     * @param port UDP port to listen on
     * @param serverCert certificate PEM path, or {@code null} for the default
     * @param serverKey private key PEM path, or {@code null} for the default
     */
    public QuicServerTransport(int port, String serverCert, String serverKey) {
        this(port, serverCert, serverKey, new SignalHttpBridge(), new SignalWebSocketBridge.SignalWebSocketBridgeFactory());
    }

    /**
     * Creates a transport with explicit bridges and (re)generates the certificate/key files.
     *
     * @param port UDP port to listen on
     * @param serverCert certificate PEM path, or {@code null} for the default
     * @param serverKey private key PEM path, or {@code null} for the default
     * @param hBridge bridge used for {@value #PROTOCOL} request/response streams
     * @param wBridgeFactory factory for the websocket bridge behind each {@value #WS_PROTOCOL} stream
     */
    public QuicServerTransport(int port, String serverCert, String serverKey, HttpBridge hBridge, WebSocketBridgeFactory wBridgeFactory) {
        this.port = port;
        this.certificateFile = serverCert == null ? DEFAULT_CERT : serverCert;
        this.keyFile = serverKey == null ? DEFAULT_KEY : serverKey;
        try {
            CertificateUtils.generateCertificate(Path.of(this.certificateFile), Path.of(this.keyFile));
        } catch (Exception ex) {
            // Best effort: startProcessing() opens these files and fails there if they are missing.
            LOG.log(Level.SEVERE, "Could not generate certificate " + this.certificateFile + " / key " + this.keyFile, ex);
        }
        this.httpBridge = hBridge;
        this.webSocketBridgeFactory = wBridgeFactory;
        LOG.info("Created KwikServerProcessor at port " + port + " with cert " + this.certificateFile + " and key = " + this.keyFile);
    }

    /**
     * Creates the kwik {@link ServerConnector} on {@link #port}, registers the
     * {@value #PROTOCOL} and {@value #WS_PROTOCOL} protocols and starts it.
     *
     * @throws Exception if the certificate/key cannot be read or the connector cannot be created
     */
    public void startProcessing() throws Exception {
        Logger quicLog = new SysOutLogger();
        List<Version> supportedVersions = new ArrayList<>();
        supportedVersions.add(Version.QUIC_version_1);
        supportedVersions.add(Version.QUIC_version_2);
        boolean requireRetry = true;
        LOG.info("Start processing");
        ServerConnector serverConnector;
        // NOTE(review): this assumes kwik's ServerConnector parses cert and key eagerly in its
        // constructor (it does in current kwik via TlsServerEngineFactory). That is what makes
        // closing the streams here safe. Re-confirm when upgrading kwik.
        try (InputStream certIn = new FileInputStream(this.certificateFile);
             InputStream keyIn = new FileInputStream(this.keyFile)) {
            serverConnector = new ServerConnector(this.port, certIn, keyIn, supportedVersions, requireRetry, quicLog);
        }
        serverConnector.registerApplicationProtocol(PROTOCOL, new SWaveConnectionFactory(this));
        serverConnector.registerApplicationProtocol(WS_PROTOCOL, new PWaveConnectionFactory(this));
        LOG.info("Starting ServerConnector");
        serverConnector.start();
    }

    /**
     * Reads a 4-byte big-endian signed int.
     *
     * @throws EOFException if the stream ends before 4 bytes were read
     */
    public static int readIntFromStream(InputStream is) throws IOException {
        return ByteBuffer.wrap(readNBytesFromStream(is, 4)).getInt();
    }

    /**
     * Reads exactly {@code total} bytes, blocking until they are available.
     *
     * <p>{@code total} usually comes from the network. {@link InputStream#readNBytes(int)} grows its
     * buffer as data arrives, so a bogus huge length does not allocate the full size up front.
     *
     * @throws IOException if {@code total} is negative
     * @throws EOFException if the stream ends before {@code total} bytes were read
     */
    public static byte[] readNBytesFromStream(InputStream is, int total) throws IOException {
        if (total < 0) {
            throw new IOException("Invalid negative length " + total);
        }
        byte[] data = is.readNBytes(total);
        if (data.length < total) {
            throw new EOFException("Stream ended after " + data.length + " of " + total + " bytes");
        }
        return data;
    }

    /** Returns a "[threadId]" prefix for log lines. */
    static String t() {
        return "[" + Thread.currentThread().getId() + "]";
    }

    /** Creates a {@link SWaveApplicationProtocolConnection} for each {@value #PROTOCOL} connection. */
    static class SWaveConnectionFactory implements ApplicationProtocolConnectionFactory {
        final QuicServerTransport transport;

        public SWaveConnectionFactory(QuicServerTransport qst) {
            this.transport = qst;
        }

        @Override
        public ApplicationProtocolConnection createConnection(String protocol, QuicConnection qc) {
            LOG.info("Create connection: protocol = " + protocol + ", qc = " + qc);
            return new SWaveApplicationProtocolConnection(this.transport);
        }
    }

    /** Creates a {@link PWaveApplicationProtocolConnection} for each {@value #WS_PROTOCOL} connection. */
    static class PWaveConnectionFactory implements ApplicationProtocolConnectionFactory {
        private final QuicServerTransport transport;

        public PWaveConnectionFactory(QuicServerTransport transport) {
            this.transport = transport;
        }

        @Override
        public ApplicationProtocolConnection createConnection(String protocol, QuicConnection qc) {
            ServerConnectionImpl sci = (ServerConnectionImpl) qc;
            LOG.info("Create connection: protocol = " + protocol + ", qc = " + qc + " and sourceIP = " + sci.getInitialClientAddress());
            return new PWaveApplicationProtocolConnection(this.transport, sci);
        }
    }

    /**
     * One {@value #WS_PROTOCOL} connection: a single client stream bridged to a websocket.
     *
     * <p>Wire format on the stream (all ints 4-byte big-endian):
     * <ol>
     *   <li>int baseUrlLength, then baseUrl as UTF-8;</li>
     *   <li>int headerCount, then per header: int keyLength, key (UTF-8), int valueLength, value (UTF-8);</li>
     *   <li>then repeated frames: int length, payload.</li>
     * </ol>
     * Replies from the websocket are sent back as int length + payload.
     *
     * <p>Two background threads run per connection. A keepalive thread sends a PING every
     * {@link #KEEPALIVE_INTERVAL_MS}. A health-check thread drops the client after
     * {@link #IDLE_TIMEOUT_MS} without traffic.
     */
    public static class PWaveApplicationProtocolConnection implements ApplicationProtocolConnection {
        /** Delay between keepalive PINGs. */
        private static final long KEEPALIVE_INTERVAL_MS = 1000L;
        /** Health-check period, and the idle time after which the client is considered gone. */
        private static final long IDLE_TIMEOUT_MS = 120000L;

        private final QuicServerTransport transport;
        private final QuicConnectionImpl quicConnection;
        // The volatile fields below are written and read by different threads: the kwik callback
        // thread, the stream reader, the keepalive thread and the health check.
        private volatile QuicStream stream;
        private volatile long lastAck = 0L;
        private volatile boolean activeClient = true;
        private volatile boolean activeSending = true;
        private Thread clientHealthThread;
        private Thread websocketHealtThread;
        private volatile WebSocketBridge sws;
        private final InetAddress sourceAddress;
        private final byte[] sourceId;
        private final long connectionId;

        public PWaveApplicationProtocolConnection(QuicServerTransport transport, ServerConnectionImpl qc) {
            LOG.info("Create PWaveConnection for " + qc + " of class " + qc.getClass());
            this.transport = transport;
            this.quicConnection = qc;
            this.sourceAddress = qc.getInitialClientAddress();
            this.sourceId = qc.getSourceConnectionId();
            this.connectionId = new BigInteger(this.sourceId).longValue();
            // Start the background threads only after every field they read has been assigned.
            this.addHealthCheck();
            this.startKeepAlivePing(qc);
            LOG.info("Got connection " + this.getConnectionInfo());
        }

        private void startKeepAlivePing(final ServerConnectionImpl con) {
            LOG.info("Start keepalive");
            this.websocketHealtThread = new Thread(() -> this.runKeepAlive(con), "Keepalive");
            this.websocketHealtThread.start();
        }

        /**
         * Keepalive thread body. Sends one PING per interval while the client is active and the
         * websocket bridge is not destroyed, then closes the QUIC connection.
         */
        private void runKeepAlive(ServerConnectionImpl con) {
            Consumer<QuicFrame> lost = qf -> LOG.warning("LOST FRAME! " + qf);
            while (this.activeSending && this.activeClient) {
                LOG.finest(() -> "PING " + con + ", last packet received = " + con.getLastPacketReceived());
                con.send((QuicFrame) new PingFrame(Version.QUIC_version_2), lost);
                try {
                    Thread.sleep(KEEPALIVE_INTERVAL_MS);
                } catch (InterruptedException ex) {
                    Thread.currentThread().interrupt();
                    LOG.log(Level.WARNING, "Keepalive interrupted for " + this.getConnectionInfo(), ex);
                    break;
                }
                // No bridge exists until the client opens its stream; keep pinging until then.
                WebSocketBridge bridge = this.sws;
                this.activeSending = bridge == null || !bridge.isDestroyed();
            }
            LOG.warning("We close " + con + " and don't send keepalive pings anymore!");
            con.close();
        }

        private void addHealthCheck() {
            this.clientHealthThread = new Thread(this::runHealthCheck, "Healthcheck");
            this.clientHealthThread.start();
        }

        /**
         * Health-check thread body. Every {@link #IDLE_TIMEOUT_MS} it checks the time since the last
         * received packet or frame. If that exceeds the timeout, it marks the client inactive and
         * closes the bridge. Assumes {@code getLastPacketReceived()} is in epoch milliseconds, like
         * {@link System#currentTimeMillis()} (TODO confirm against kwik).
         */
        private void runHealthCheck() {
            while (this.activeClient && this.activeSending) {
                try {
                    Thread.sleep(IDLE_TIMEOUT_MS);
                } catch (InterruptedException ex) {
                    Thread.currentThread().interrupt();
                    LOG.log(Level.WARNING, "Health check interrupted for " + this.getConnectionInfo(), ex);
                    break;
                }
                this.lastAck = Math.max(this.lastAck, this.quicConnection.getLastPacketReceived());
                long elapsed = System.currentTimeMillis() - this.lastAck;
                if (elapsed > IDLE_TIMEOUT_MS) {
                    LOG.warning("Didn't receive an ack from " + this.getConnectionInfo() + " for " + elapsed + " ms. Remove this.");
                    this.activeClient = false;
                    WebSocketBridge bridge = this.sws;
                    if (bridge != null) {
                        bridge.close();
                    }
                }
            }
            LOG.warning("We closed " + this.sws + " and don't send keepalive pings anymore!");
        }

        /**
         * Accepts the single client-opened stream.
         *
         * @throws IllegalStateException if the client opens a second stream
         */
        @Override
        public void acceptPeerInitiatedStream(QuicStream stream) {
            if (this.stream != null) {
                LOG.severe("Got a request for a second stream, bailing.");
                throw new IllegalStateException("We don't want a client to create 2 streams to the same pwave");
            }
            LOG.info(t() + "PWave ACCEPTPeerInitiatedStream: " + stream + " with id = " + stream.getStreamId() + " for " + this.getConnectionInfo());
            this.stream = stream;
            this.startStreamProcess(stream);
        }

        /**
         * Reads the base URL and headers on the calling thread, then creates the websocket bridge.
         * A new thread then forwards the remaining frames to that bridge.
         */
        private void startStreamProcess(final QuicStream stream) {
            LOG.info(t() + "Got stream: " + stream + " and qc = " + this.getConnectionInfo());
            InputStream inputStream = stream.getInputStream();
            try {
                String baseUrl = readLengthPrefixedString(inputStream);
                LOG.info("Got baseurl = KWAAAK and qc = " + this.quicConnection);
                int headersCount = readIntFromStream(inputStream);
                HashMap<String, String> headers = new HashMap<>();
                for (int i = 0; i < headersCount; ++i) {
                    String key = readLengthPrefixedString(inputStream);
                    String val = readLengthPrefixedString(inputStream);
                    headers.put(key, val);
                }
                LOG.info(t() + "Got all data for websocket for " + this.getConnectionInfo());
                Consumer<byte[]> messageCallback = this::processWebSocketRpcReplyMessage;
                WebSocketBridge bridge = this.transport.webSocketBridgeFactory.createWebSocketBridge(this, baseUrl, headers, messageCallback);
                this.sws = bridge;
                new Thread(() -> this.processIncomingQuikStream(stream, bridge)).start();
            } catch (Exception ex) {
                LOG.log(Level.SEVERE, "Failed to set up websocket bridge for " + this.getConnectionInfo(), ex);
            }
        }

        /** Reads an int length followed by that many bytes, decoded as UTF-8. */
        private static String readLengthPrefixedString(InputStream in) throws IOException {
            int length = readIntFromStream(in);
            return new String(readNBytesFromStream(in, length), StandardCharsets.UTF_8);
        }

        public String getConnectionInfo() {
            return "Connection from " + this.sourceAddress + " with id " + this.connectionId + " and hash " + this.quicConnection;
        }

        /**
         * Forwards length-prefixed frames from the QUIC stream to {@code destination}. On the first
         * I/O error, including end-of-stream, it closes {@code destination} and returns.
         */
        void processIncomingQuikStream(QuicStream stream, WebSocketBridge destination) {
            InputStream is = stream.getInputStream();
            while (true) {
                try {
                    LOG.info("Waiting for input on quicstream " + stream + " for " + this.getConnectionInfo());
                    int len = readIntFromStream(is);
                    LOG.info("Got " + len + " bytes from other quic endpoint from " + this.getConnectionInfo());
                    byte[] payload = readNBytesFromStream(is, len);
                    this.lastAck = System.currentTimeMillis();
                    LOG.info("Done reading " + payload.length + " bytes, now send them to websocket for " + this.getConnectionInfo());
                    destination.sendData(payload);
                    LOG.info("Done sending payload to websocket for connection " + this.getConnectionInfo());
                } catch (IOException ex) {
                    LOG.warning("Got an error reading quicstream from client to proxy. Stop reading and close WS for " + this.getConnectionInfo());
                    destination.close();
                    return;
                }
            }
        }

        /**
         * Queues {@code raw} to be written to the client as int length + payload. Writes go through
         * the shared single-threaded executor, so frames are never interleaved.
         */
        void processWebSocketRpcReplyMessage(byte[] raw) {
            // execute() rather than submit(): an unexpected RuntimeException then reaches the
            // thread's uncaught-exception handler instead of vanishing inside an ignored Future.
            executorService.execute(() -> {
                QuicStream target = this.stream;
                LOG.info(t() + "Process a reply and send it via " + target + " to " + this.getConnectionInfo());
                byte[] sizeBytes = ByteBuffer.allocate(4).putInt(raw.length).array();
                try {
                    target.getOutputStream().write(sizeBytes);
                    target.getOutputStream().write(raw);
                    target.getOutputStream().flush();
                } catch (IOException ex) {
                    LOG.log(Level.SEVERE, "IOException while sending ws reply not yet handled.", ex);
                }
                LOG.info(t() + "Processed a reply and send it via " + target);
            });
        }
    }

    /**
     * One {@value #PROTOCOL} connection. Each peer-initiated stream is served on its own thread:
     * read the request to end-of-stream, forward it to the {@link HttpBridge}, then write the reply
     * and close the stream's output.
     */
    static class SWaveApplicationProtocolConnection implements ApplicationProtocolConnection {
        private final QuicServerTransport transport;

        public SWaveApplicationProtocolConnection(QuicServerTransport transport) {
            this.transport = transport;
        }

        @Override
        public void acceptPeerInitiatedStream(QuicStream stream) {
            LOG.info("ACCEPTPeerInitiatedStream: " + stream);
            this.startStreamProcess(stream);
        }

        private void startStreamProcess(final QuicStream stream) {
            final int streamId = stream.getStreamId();
            LOG.info(t() + "Processing stream with id " + streamId);
            new Thread(() -> this.serveRequest(stream, streamId)).start();
        }

        /** Handles one request/response exchange; failures are logged, not propagated. */
        private void serveRequest(QuicStream stream, int streamId) {
            try {
                LOG.info(t() + "Start reading from incoming stream for " + streamId);
                byte[] request = stream.getInputStream().readAllBytes();
                byte[] reply = this.transport.httpBridge.sendRequestToSignalServer(request);
                stream.getOutputStream().write(reply);
                stream.getOutputStream().flush();
                stream.getOutputStream().close();
                LOG.info(t() + "Done writing output to " + streamId);
            } catch (Exception ex) {
                LOG.log(Level.SEVERE, "Failed to serve request on stream " + streamId, ex);
            }
        }
    }
}

