/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kudu.client;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.protobuf.ByteString;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.ExtensionRegistryLite;
import com.google.protobuf.Message;
import com.google.protobuf.MessageLite;
import com.google.protobuf.MessageOrBuilder;
import com.google.protobuf.TextFormat;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.ssl.SslHandler;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.security.KeyStore;
import java.security.cert.Certificate;
import java.util.ArrayList;
import java.util.List;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import org.apache.kudu.client.Bytes;
import org.apache.kudu.client.CallResponse;
import org.apache.kudu.client.KuduRpc;
import org.apache.kudu.client.Negotiator;
import org.apache.kudu.client.RpcOutboundMessage;
import org.apache.kudu.client.SecurityContext;
import org.apache.kudu.rpc.RpcHeader;
import org.apache.kudu.security.Token;
import org.apache.kudu.test.junit.RetryRule;
import org.apache.kudu.util.SecurityUtil;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TestNegotiator {
    // Class logger; used by runServerStep to dump the client's TLS bytes at debug level.
    static final Logger LOG = LoggerFactory.getLogger(TestNegotiator.class);
    // Channel that hosts the Negotiator under test; assigned by startNegotiation().
    private EmbeddedChannel embedder;
    // Security context passed to the Negotiator; replaced with a fresh instance in setUp().
    private SecurityContext secContext;
    // Server-mode TLS engine playing the peer's side of the handshake; rebuilt in setUp().
    private SSLEngine serverEngine;
    // Password used both to load the test keystore and to initialise its KeyManagerFactory.
    private static final char[] KEYSTORE_PASSWORD = "password".toCharArray();
    // NOTE(review): despite the "_DER" suffix, the constant below holds PEM text
    // (it is wrapped in BEGIN/END CERTIFICATE armour), not raw DER bytes.
    private static final String CA_CERT_DER = "-----BEGIN CERTIFICATE-----\nMIIDXTCCAkWgAwIBAgIJAOOmFHYkBz4rMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwHhcNMTYxMTAyMjI0OTQ5WhcNMTcwMjEwMjI0OTQ5WjBFMQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAppo9GwiDisQVYAF9NXl8ykqo0MIi5rfNwiE9kUWbZ2ejzxs+1Cf7WCn4mzbkJx5ZscRjhnNb6dJxtZJeid/qgiNVBcNzh35H8J+ao0tEbHjCs7rKOX0etsFUp4GQwYkdfpvVBsU8ciXvkxhvt1XjSU3/YJJRAvCyGVxUQlKiVKGCD4OnFNBwMdNw7qI8ryiRv++7I9udfSuM713yMeBtkkV7hWUfxrTgQOLsV/CS+TsSoOJ7JJqHozeZ+VYom85UqSfpIFJVzM6S7BTb6SX/vwYIoS70gubT3HbHgDRcMvpCye1npHL9fL7B87XZn7wnnUem0eeCqWyUjJ82Uj9mQQIDAQABo1AwTjAdBgNVHQ4EFgQUOY7rpWGoZMrmyRZ9RohPWVwyPBowHwYDVR0jBBgwFoAUOY7rpWGoZMrmyRZ9RohPWVwyPBowDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEATKh3io8ruqbhmopY3xQWA2pEhs4ZSu3H+AfULMruVsXKEZjWp27nTsFaxLZYUlzeZr0EcWwZ79qkcA8Dyj+mVHhrCAPpcjsDACh1ZdUQAgASkVS4VQvkukct3DFa3y0lz5VwQIxjoQR5y6dCvxxXT9NpRo/Z7pd4MRhEbz3NT6PScQ9f2MTrR0NOikLdB98JlpKQbEKxzbMhWDw4J3mrmK6zdemjdCcRDsBVPswKnyAjkibXaZkpNRzjvDNAgO88MKlArCYoyRZqIfkcSXAwwTdGQ+5GQLsY9zS49Rrhk9R7eOmDhaHybdRBDqW1JiCSmzURZAxlnrjox4GmC3JJaA==\n-----END CERTIFICATE-----";
    // JUnit rule applied to every test in this class.
    // NOTE(review): presumably re-runs failing tests; confirm against RetryRule's implementation.
    @Rule
    public RetryRule retryRule = new RetryRule();

    /**
     * Prepares per-test state: a server-mode TLS engine backed by the test keystore
     * and a brand-new {@link SecurityContext}.
     */
    @Before
    public void setUp() {
        SSLEngine engine = createServerEngine();
        serverEngine = engine;
        // The engine plays the server role in every handshake driven by these tests.
        engine.setUseClientMode(false);
        secContext = new SecurityContext();
    }

    /**
     * Creates a {@link Negotiator} for "127.0.0.1", installs it as the sole handler of a
     * new {@link EmbeddedChannel} (stored in {@code embedder}) and asks it to send its
     * hello message.
     *
     * @param fakeLoopback value assigned to the negotiator's {@code overrideLoopbackForTests} field
     */
    private void startNegotiation(boolean fakeLoopback) {
        Negotiator clientNegotiator = new Negotiator("127.0.0.1", secContext, false);
        clientNegotiator.overrideLoopbackForTests = fakeLoopback;
        embedder = new EmbeddedChannel(clientNegotiator);
        clientNegotiator.sendHello(embedder.pipeline().firstContext());
    }

    /**
     * Serialises {@code header} and {@code body} with {@code KuduRpc.toByteBuf} and wraps
     * the result, minus its first four bytes, in a {@link CallResponse}.
     *
     * @param header response header to serialise
     * @param body   response body to serialise
     * @return a CallResponse backed by a slice of the serialised buffer
     */
    static CallResponse fakeResponse(RpcHeader.ResponseHeader header, Message body) {
        ByteBuf framed = Unpooled.buffer();
        KuduRpc.toByteBuf(framed, header, body);
        // Drop the leading 4 bytes (presumably the frame-length prefix; confirm against
        // KuduRpc.toByteBuf) before handing the payload to CallResponse.
        int payloadLength = framed.readableBytes() - 4;
        return new CallResponse(framed.slice(4, payloadLength));
    }

    /**
     * Loads the JKS keystore bundled as the classpath resource
     * {@code /test-key-and-cert.jks}, using {@code KEYSTORE_PASSWORD}.
     *
     * @return the loaded keystore
     * @throws IllegalStateException if the resource is not on the classpath
     * @throws Exception if the keystore cannot be instantiated or read
     */
    KeyStore loadTestKeystore() throws Exception {
        KeyStore ks = KeyStore.getInstance("JKS");
        try (InputStream stream = TestNegotiator.class.getResourceAsStream("/test-key-and-cert.jks")) {
            // KeyStore.load(null, ...) silently initialises an EMPTY keystore, which would
            // only surface later as an obscure failure. Fail fast if the resource is missing.
            if (stream == null) {
                throw new IllegalStateException(
                        "test keystore resource /test-key-and-cert.jks not found on the classpath");
            }
            ks.load(stream, KEYSTORE_PASSWORD);
        }
        return ks;
    }

    /**
     * Builds an {@link SSLEngine} from a "TLS" {@link SSLContext} whose key managers come
     * from the test keystore. Default trust managers and randomness are used.
     *
     * @return a newly created engine (client/server mode is left for the caller to set)
     * @throws RuntimeException wrapping any exception raised while building the engine
     */
    SSLEngine createServerEngine() {
        try {
            KeyManagerFactory keyManagers = KeyManagerFactory.getInstance("SunX509");
            KeyStore testKeys = loadTestKeystore();
            keyManagers.init(testKeys, KEYSTORE_PASSWORD);

            SSLContext tlsContext = SSLContext.getInstance("TLS");
            tlsContext.init(keyManagers.getKeyManagers(), null, null);
            return tlsContext.createSSLEngine();
        } catch (Exception cause) {
            throw new RuntimeException(cause);
        }
    }

    /**
     * Verifies the final stage of negotiation: the next outbound message must be a
     * connection context with call id -3 whose real user equals the {@code user.name}
     * system property, and a {@link Negotiator.Success} must be available inbound.
     *
     * @param isTls when true the outbound message is read as a {@link ByteBuf} and passed
     *              through {@code unwrapOutboundMessage}; otherwise it is read directly
     *              as an {@link RpcOutboundMessage}
     * @return the non-null success object read from the channel's inbound queue
     * @throws Exception if unwrapping the outbound message fails
     */
    private Negotiator.Success assertComplete(boolean isTls) throws Exception {
        final RpcOutboundMessage contextMsg;
        if (isTls) {
            ByteBuf encrypted = (ByteBuf) embedder.readOutbound();
            contextMsg = unwrapOutboundMessage(encrypted, RpcHeader.ConnectionContextPB.newBuilder());
        } else {
            contextMsg = (RpcOutboundMessage) embedder.readOutbound();
        }

        RpcHeader.ConnectionContextPB context = (RpcHeader.ConnectionContextPB) contextMsg.getBody();
        Assert.assertEquals(-3L, contextMsg.getHeaderBuilder().getCallId());
        String expectedUser = System.getProperty("user.name");
        Assert.assertEquals(expectedUser, context.getDEPRECATEDUserInfo().getRealUser());

        Negotiator.Success outcome = (Negotiator.Success) embedder.readInbound();
        Assert.assertNotNull(outcome);
        return outcome;
    }

    /**
     * Checks that the endpoint channel bindings computed for the keystore entry with
     * alias "1" are exactly 32 bytes long.
     */
    @Test
    public void testChannelBinding() throws Exception {
        Certificate serverCert = loadTestKeystore().getCertificate("1");
        byte[] channelBindings = SecurityUtil.getEndpointChannelBindings(serverCert);
        // NOTE(review): 32 bytes presumably corresponds to a SHA-256 digest; confirm.
        Assert.assertEquals(32L, channelBindings.length);
    }

    /**
     * Walks a plain (non-TLS) negotiation: client NEGOTIATE, server NEGOTIATE offering
     * the PLAIN mechanism, client SASL_INITIATE, server SASL_SUCCESS, then completion.
     */
    @Test
    public void testNegotiation() throws Exception {
        startNegotiation(false);

        // The client opens with a NEGOTIATE step using call id -33.
        RpcOutboundMessage hello = (RpcOutboundMessage) embedder.readOutbound();
        RpcHeader.NegotiatePB helloBody = (RpcHeader.NegotiatePB) hello.getBody();
        Assert.assertEquals(-33L, hello.getHeaderBuilder().getCallId());
        Assert.assertEquals(RpcHeader.NegotiatePB.NegotiateStep.NEGOTIATE, helloBody.getStep());

        // The server answers with NEGOTIATE, advertising only the PLAIN mechanism.
        RpcHeader.ResponseHeader negotiateHeader =
                RpcHeader.ResponseHeader.newBuilder().setCallId(-33).build();
        RpcHeader.NegotiatePB serverNegotiate = RpcHeader.NegotiatePB.newBuilder()
                .addSaslMechanisms(RpcHeader.NegotiatePB.SaslMechanism.newBuilder().setMechanism("PLAIN"))
                .setStep(RpcHeader.NegotiatePB.NegotiateStep.NEGOTIATE)
                .build();
        embedder.writeInbound(fakeResponse(negotiateHeader, serverNegotiate));
        embedder.flushInbound();

        // The client must initiate SASL with exactly that mechanism and include a token.
        RpcOutboundMessage initiate = (RpcOutboundMessage) embedder.readOutbound();
        RpcHeader.NegotiatePB initiateBody = (RpcHeader.NegotiatePB) initiate.getBody();
        Assert.assertEquals(-33L, initiate.getHeaderBuilder().getCallId());
        Assert.assertEquals(RpcHeader.NegotiatePB.NegotiateStep.SASL_INITIATE, initiateBody.getStep());
        Assert.assertEquals(1L, initiateBody.getSaslMechanismsCount());
        Assert.assertEquals("PLAIN", initiateBody.getSaslMechanisms(0).getMechanism());
        Assert.assertTrue(initiateBody.hasToken());

        // The server reports SASL success, which should finish the negotiation.
        RpcHeader.ResponseHeader successHeader =
                RpcHeader.ResponseHeader.newBuilder().setCallId(-33).build();
        RpcHeader.NegotiatePB saslSuccess = RpcHeader.NegotiatePB.newBuilder()
                .setStep(RpcHeader.NegotiatePB.NegotiateStep.SASL_SUCCESS)
                .build();
        embedder.writeInbound(fakeResponse(successHeader, saslSuccess));
        embedder.flushInbound();

        assertComplete(false);
    }

    /**
     * Runs every delegated task currently queued on {@code engine}, but only when
     * {@code result} reports a handshake status of NEED_TASK; otherwise does nothing.
     *
     * @param result outcome of the preceding wrap/unwrap call
     * @param engine engine whose delegated tasks should be drained
     */
    private static void runTasks(SSLEngineResult result, SSLEngine engine) {
        if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) {
            for (Runnable pending = engine.getDelegatedTask();
                 pending != null;
                 pending = engine.getDelegatedTask()) {
                pending.run();
            }
        }
    }

    /**
     * Feeds one client TLS handshake message into the server-side {@code engine} and
     * collects whatever the server needs to send back.
     *
     * @param engine           server-side TLS engine driving the handshake
     * @param clientTlsMessage raw TLS bytes produced by the client
     * @return a fake TLS_HANDSHAKE response (call id -33) carrying the concatenated server
     *         output, or {@code null} when the engine reports NOT_HANDSHAKING after
     *         consuming the client message
     * @throws SSLException   if the engine rejects the client message
     * @throws AssertionError if the client message is truncated, or the engine ends up in
     *                        a handshake state this helper does not handle
     */
    private static CallResponse runServerStep(SSLEngine engine, ByteString clientTlsMessage) throws SSLException {
        // Materialise the bytes once; they are needed for both logging and unwrapping.
        byte[] clientBytes = clientTlsMessage.toByteArray();
        LOG.debug("Handling TLS message from client: {}", Bytes.hex(clientBytes));
        ByteBuffer dst = ByteBuffer.allocate(engine.getSession().getPacketBufferSize());
        ByteBuffer src = ByteBuffer.wrap(clientBytes);
        do {
            SSLEngineResult result = engine.unwrap(src, dst);
            runTasks(result, engine);
            // On BUFFER_UNDERFLOW the engine consumed nothing. If it still wants to unwrap,
            // looping again with the same source would spin forever, so fail instead.
            if (result.getStatus() == SSLEngineResult.Status.BUFFER_UNDERFLOW
                    && engine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
                throw new AssertionError("incomplete TLS message from client: "
                        + src.remaining() + " byte(s) could not be unwrapped");
            }
        } while (engine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP);

        if (engine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_WRAP) {
            // The server has handshake data to send: gather every record it produces.
            List<ByteString> serverRecords = new ArrayList<>();
            while (engine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_WRAP) {
                dst.clear();
                runTasks(engine.wrap(ByteBuffer.allocate(0), dst), engine);
                dst.flip();
                serverRecords.add(ByteString.copyFrom(dst));
            }
            return fakeResponse(
                    RpcHeader.ResponseHeader.newBuilder().setCallId(-33).build(),
                    RpcHeader.NegotiatePB.newBuilder()
                            .setTlsHandshake(ByteString.copyFrom(serverRecords))
                            .setStep(RpcHeader.NegotiatePB.NegotiateStep.TLS_HANDSHAKE)
                            .build());
        }
        if (engine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
            // Nothing left to send for this step.
            return null;
        }
        throw new AssertionError("unexpected state: " + engine.getHandshakeStatus());
    }

    /**
     * Drives two client/server round trips of the TLS handshake: each client
     * TLS_HANDSHAKE message read from the channel is passed to {@code runServerStep} and
     * the server's reply is written back inbound.
     *
     * @param isAuthOnly when false, an {@code RpcOutboundMessage.Encoder} named
     *                   "encode-outbound" is added at the head of the pipeline before the
     *                   second server reply is delivered
     * @throws SSLException if the server engine rejects a client handshake message
     */
    private void runTlsHandshake(boolean isAuthOnly) throws SSLException {
        // Round 1: first client handshake message -> server reply.
        RpcOutboundMessage firstMsg = (RpcOutboundMessage) embedder.readOutbound();
        RpcHeader.NegotiatePB firstBody = (RpcHeader.NegotiatePB) firstMsg.getBody();
        Assert.assertEquals(RpcHeader.NegotiatePB.NegotiateStep.TLS_HANDSHAKE, firstBody.getStep());
        CallResponse firstReply = runServerStep(serverEngine, firstBody.getTlsHandshake());
        embedder.writeInbound(firstReply);
        embedder.flushInbound();

        // Round 2: second client handshake message.
        RpcOutboundMessage secondMsg = (RpcOutboundMessage) embedder.readOutbound();
        RpcHeader.NegotiatePB secondBody = (RpcHeader.NegotiatePB) secondMsg.getBody();
        Assert.assertEquals(RpcHeader.NegotiatePB.NegotiateStep.TLS_HANDSHAKE, secondBody.getStep());

        // Callers that negotiate full TLS read later outbound traffic as ByteBuf, so the
        // encoder must be in place before the final server reply is delivered.
        if (!isAuthOnly) {
            embedder.pipeline().addFirst("encode-outbound", new RpcOutboundMessage.Encoder());
        }

        CallResponse secondReply = runServerStep(serverEngine, secondBody.getTlsHandshake());
        embedder.writeInbound(secondReply);
        embedder.flushInbound();
    }

    /**
     * Negotiates with a server that advertises TLS: after the handshake an
     * {@link SslHandler} must sit first in the pipeline and the next (encrypted) client
     * message must be a SASL_INITIATE step.
     */
    @Test
    public void testTlsNegotiation() throws Exception {
        startNegotiation(false);

        // The client's hello must advertise TLS support.
        RpcOutboundMessage hello = (RpcOutboundMessage) embedder.readOutbound();
        RpcHeader.NegotiatePB helloBody = (RpcHeader.NegotiatePB) hello.getBody();
        Assert.assertEquals(RpcHeader.NegotiatePB.NegotiateStep.NEGOTIATE, helloBody.getStep());
        Assert.assertTrue(helloBody.getSupportedFeaturesList().contains(RpcHeader.RpcFeatureFlag.TLS));

        // The server replies with PLAIN plus the TLS feature flag.
        RpcHeader.ResponseHeader replyHeader =
                RpcHeader.ResponseHeader.newBuilder().setCallId(-33).build();
        RpcHeader.NegotiatePB serverNegotiate = RpcHeader.NegotiatePB.newBuilder()
                .addSaslMechanisms(RpcHeader.NegotiatePB.SaslMechanism.newBuilder().setMechanism("PLAIN"))
                .addSupportedFeatures(RpcHeader.RpcFeatureFlag.TLS)
                .setStep(RpcHeader.NegotiatePB.NegotiateStep.NEGOTIATE)
                .build();
        embedder.writeInbound(fakeResponse(replyHeader, serverNegotiate));
        embedder.flushInbound();

        runTlsHandshake(false);

        // Full TLS: the SslHandler must now be at the head of the pipeline.
        Assert.assertTrue(embedder.pipeline().first() instanceof SslHandler);

        // The follow-up message arrives encrypted and must be decrypted before inspection.
        ByteBuf encrypted = (ByteBuf) embedder.readOutbound();
        RpcOutboundMessage saslMsg = unwrapOutboundMessage(encrypted, RpcHeader.NegotiatePB.newBuilder());
        RpcHeader.NegotiatePB saslBody = (RpcHeader.NegotiatePB) saslMsg.getBody();
        Assert.assertEquals(RpcHeader.NegotiatePB.NegotiateStep.SASL_INITIATE, saslBody.getStep());
    }

    /**
     * Negotiates with the loopback override enabled against a server that advertises both
     * TLS and TLS_AUTHENTICATION_ONLY: after the handshake no {@link SslHandler} may be
     * first in the pipeline and the next client message is a plain SASL_INITIATE step.
     */
    @Test
    public void testTlsNegotiationAuthOnly() throws Exception {
        startNegotiation(true);

        // With the loopback override the client advertises both TLS feature flags.
        RpcOutboundMessage hello = (RpcOutboundMessage) embedder.readOutbound();
        RpcHeader.NegotiatePB helloBody = (RpcHeader.NegotiatePB) hello.getBody();
        Assert.assertEquals(RpcHeader.NegotiatePB.NegotiateStep.NEGOTIATE, helloBody.getStep());
        Assert.assertTrue(helloBody.getSupportedFeaturesList().contains(RpcHeader.RpcFeatureFlag.TLS));
        Assert.assertTrue(helloBody.getSupportedFeaturesList()
                .contains(RpcHeader.RpcFeatureFlag.TLS_AUTHENTICATION_ONLY));

        // The server mirrors both flags and offers PLAIN.
        RpcHeader.ResponseHeader replyHeader =
                RpcHeader.ResponseHeader.newBuilder().setCallId(-33).build();
        RpcHeader.NegotiatePB serverNegotiate = RpcHeader.NegotiatePB.newBuilder()
                .addSaslMechanisms(RpcHeader.NegotiatePB.SaslMechanism.newBuilder().setMechanism("PLAIN"))
                .addSupportedFeatures(RpcHeader.RpcFeatureFlag.TLS)
                .addSupportedFeatures(RpcHeader.RpcFeatureFlag.TLS_AUTHENTICATION_ONLY)
                .setStep(RpcHeader.NegotiatePB.NegotiateStep.NEGOTIATE)
                .build();
        embedder.writeInbound(fakeResponse(replyHeader, serverNegotiate));
        embedder.flushInbound();

        runTlsHandshake(true);

        // Authentication-only TLS: the SslHandler must not be at the head of the pipeline.
        Assert.assertFalse(embedder.pipeline().first() instanceof SslHandler);

        // The follow-up message is therefore readable without decryption.
        RpcOutboundMessage saslMsg = (RpcOutboundMessage) embedder.readOutbound();
        RpcHeader.NegotiatePB saslBody = (RpcHeader.NegotiatePB) saslMsg.getBody();
        Assert.assertEquals(RpcHeader.NegotiatePB.NegotiateStep.SASL_INITIATE, saslBody.getStep());
    }

    /**
     * With an authentication token set but no trusted certificates, the client's hello
     * must list only the SASL authentication type.
     */
    @Test
    public void testNoTokenAuthWhenNoTrustedCerts() throws Exception {
        secContext.setAuthenticationToken(Token.SignedTokenPB.getDefaultInstance());
        startNegotiation(false);

        RpcOutboundMessage hello = (RpcOutboundMessage) embedder.readOutbound();
        RpcHeader.NegotiatePB helloBody = (RpcHeader.NegotiatePB) hello.getBody();
        String rendered = TextFormat.shortDebugString(helloBody);
        Assert.assertEquals(
                "supported_features: APPLICATION_FEATURE_FLAGS supported_features: TLS step: NEGOTIATE authn_types { sasl { } }",
                rendered);
    }

    /**
     * With both a trusted certificate and an authentication token configured, the client
     * must offer token authentication, complete the TLS handshake, send a TOKEN_EXCHANGE
     * step, and finish negotiation once the server acknowledges it.
     */
    @Test
    public void testTokenAuthWithTrustedCerts() throws Exception {
        secContext.trustCertificates(ImmutableList.of(ByteString.copyFromUtf8(CA_CERT_DER)));
        secContext.setAuthenticationToken(Token.SignedTokenPB.getDefaultInstance());
        startNegotiation(false);

        // The client's hello must list both SASL and token authentication types.
        RpcOutboundMessage hello = (RpcOutboundMessage) embedder.readOutbound();
        RpcHeader.NegotiatePB helloBody = (RpcHeader.NegotiatePB) hello.getBody();
        Assert.assertEquals(
                "supported_features: APPLICATION_FEATURE_FLAGS supported_features: TLS step: NEGOTIATE authn_types { sasl { } } authn_types { token { } }",
                TextFormat.shortDebugString(helloBody));

        // The server selects token authentication and advertises TLS.
        RpcHeader.ResponseHeader negotiateHeader =
                RpcHeader.ResponseHeader.newBuilder().setCallId(-33).build();
        RpcHeader.NegotiatePB serverNegotiate = RpcHeader.NegotiatePB.newBuilder()
                .addSupportedFeatures(RpcHeader.RpcFeatureFlag.TLS)
                .addAuthnTypes(RpcHeader.AuthenticationTypePB.newBuilder()
                        .setToken(RpcHeader.AuthenticationTypePB.Token.getDefaultInstance()))
                .setStep(RpcHeader.NegotiatePB.NegotiateStep.NEGOTIATE)
                .build();
        embedder.writeInbound(fakeResponse(negotiateHeader, serverNegotiate));
        embedder.flushInbound();

        runTlsHandshake(false);

        // The client's next (encrypted) message carries the token.
        ByteBuf encryptedExchange = (ByteBuf) embedder.readOutbound();
        RpcOutboundMessage exchange =
                unwrapOutboundMessage(encryptedExchange, RpcHeader.NegotiatePB.newBuilder());
        RpcHeader.NegotiatePB exchangeBody = (RpcHeader.NegotiatePB) exchange.getBody();
        Assert.assertEquals("step: TOKEN_EXCHANGE authn_token { }",
                TextFormat.shortDebugString(exchangeBody));

        // The server acknowledges the token exchange.
        RpcHeader.ResponseHeader exchangeHeader =
                RpcHeader.ResponseHeader.newBuilder().setCallId(-33).build();
        RpcHeader.NegotiatePB serverExchange = RpcHeader.NegotiatePB.newBuilder()
                .setStep(RpcHeader.NegotiatePB.NegotiateStep.TOKEN_EXCHANGE)
                .build();
        embedder.writeInbound(fakeResponse(exchangeHeader, serverExchange));
        embedder.flushInbound();

        // An empty buffer is emitted before the connection context.
        ByteBuf emptyBuf = (ByteBuf) embedder.readOutbound();
        Assert.assertEquals(0L, emptyBuf.readableBytes());

        assertComplete(true);
    }

    /**
     * Decrypts a TLS-wrapped outbound message with the shared server engine and parses it
     * into a request header plus a body built with {@code requestBuilder}.
     *
     * @param wrappedBuf     encrypted bytes read from the client channel
     * @param requestBuilder builder for the expected body type; it is populated and built
     * @return the decoded header and body as an {@link RpcOutboundMessage}
     * @throws AssertionError if the SslHandler yields no plaintext for {@code wrappedBuf}
     * @throws Exception if the plaintext cannot be parsed
     */
    private RpcOutboundMessage unwrapOutboundMessage(ByteBuf wrappedBuf, Message.Builder requestBuilder) throws Exception {
        // Run the ciphertext through a throw-away channel holding an SslHandler that is
        // backed by the shared server engine.
        // NOTE(review): the channel is deliberately left open. Closing it would presumably
        // close serverEngine, which later calls reuse; confirm before adding cleanup.
        EmbeddedChannel serverSslChannel = new EmbeddedChannel(new SslHandler(serverEngine));
        serverSslChannel.writeInbound(wrappedBuf);
        serverSslChannel.flushInbound();

        ByteBuf plaintext = (ByteBuf) serverSslChannel.readInbound();
        // Fail with a clear message instead of an anonymous NullPointerException.
        Assert.assertNotNull("SslHandler produced no plaintext for the wrapped message", plaintext);

        final byte[] payload;
        try {
            // The plaintext starts with a 4-byte size followed by that many payload bytes.
            int size = plaintext.readInt();
            payload = new byte[size];
            plaintext.getBytes(plaintext.readerIndex(), payload);
        } finally {
            // The payload has been copied out, so the reference-counted buffer can go.
            plaintext.release();
        }

        // The payload holds two consecutive messages: the request header, then the body.
        CodedInputStream in = CodedInputStream.newInstance(payload);
        RpcHeader.RequestHeader.Builder header = RpcHeader.RequestHeader.newBuilder();
        in.readMessage(header, ExtensionRegistry.getEmptyRegistry());
        in.readMessage(requestBuilder, ExtensionRegistry.getEmptyRegistry());
        return new RpcOutboundMessage(header, requestBuilder.build());
    }
}

