/*
 * Decompiled with CFR 0.152.
 */
package org.apache.livy.rsc.rpc;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.util.concurrent.EventExecutorGroup;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.ImmediateEventExecutor;
import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.ScheduledFuture;
import java.io.Closeable;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import org.apache.livy.rsc.RSCConf;
import org.apache.livy.rsc.Utils;
import org.apache.livy.rsc.rpc.KryoMessageCodec;
import org.apache.livy.rsc.rpc.RpcDispatcher;
import org.apache.livy.rsc.rpc.RpcException;
import org.apache.livy.rsc.rpc.SaslHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class Rpc
implements Closeable {
    private static final Logger LOG = LoggerFactory.getLogger(Rpc.class);
    // SASL identifiers shared by the RPC endpoints.
    static final String SASL_REALM = "rsc";
    static final String SASL_USER = "rsc";
    static final String SASL_PROTOCOL = "rsc";
    // NOTE(review): presumably the SASL QOP value that enables encryption; not referenced in this file - confirm in SaslHandler.
    static final String SASL_AUTH_CONF = "auth-conf";
    private final RSCConf config;
    // Flipped exactly once by close() so the channel is only closed by the first caller.
    private final AtomicBoolean rpcClosed;
    // Source of the ids placed in the header of every outgoing CALL.
    private final AtomicLong rpcId;
    private final Channel channel;
    // Supplies the executors on which the promises returned by call() are created.
    private final EventExecutorGroup egroup;
    // Null until setDispatcher() (or createEmbedded()) installs one.
    private volatile RpcDispatcher dispatcher;
    // Cache of resolved "handle" methods, keyed by the class of the incoming message.
    private final Map<Class<?>, Method> handlers = new ConcurrentHashMap();
    // Calls that have been registered and whose REPLY/ERROR has not been matched yet.
    private final Collection<OutstandingRpc> rpcCalls = new ConcurrentLinkedQueue<OutstandingRpc>();
    // Header of the message currently being received; null while waiting for the next header.
    private volatile MessageHeader lastHeader;

    /**
     * Opens a connection to {@code host:port} and starts the SASL handshake as a client.
     *
     * <p>The returned promise succeeds with the ready {@link Rpc} once the handshake
     * completes (see {@code SaslClientHandler#onComplete}). It fails when the connection
     * cannot be established, when setting up the handshake throws, or when
     * {@code RPC_CLIENT_HANDSHAKE_TIMEOUT} elapses first. Cancelling the promise cancels
     * the pending connection attempt and the handshake timeout.
     *
     * @param config     RSC configuration (connect / handshake timeouts, SASL settings)
     * @param eloop      event loop group driving the channel and the timeout task
     * @param host       server host
     * @param port       server port
     * @param clientId   client identifier sent in the SASL hello
     * @param secret     shared secret used for SASL authentication
     * @param dispatcher dispatcher installed on the Rpc once the handshake completes
     * @return promise completed with the connected Rpc, or failed as described above
     */
    public static Promise<Rpc> createClient(final RSCConf config, final EventLoopGroup eloop, String host, int port, final String clientId, final String secret, final RpcDispatcher dispatcher) throws Exception {
        int connectTimeoutMs = (int)config.getTimeAsMs(RSCConf.Entry.RPC_CLIENT_CONNECT_TIMEOUT);
        final ChannelFuture cf = ((Bootstrap)((Bootstrap)((Bootstrap)((Bootstrap)((Bootstrap)new Bootstrap().group(eloop)).handler((ChannelHandler)new ChannelInboundHandlerAdapter(){})).channel(NioSocketChannel.class)).option(ChannelOption.SO_KEEPALIVE, (Object)true)).option(ChannelOption.CONNECT_TIMEOUT_MILLIS, (Object)connectTimeoutMs)).connect(host, port);
        final Promise promise = eloop.next().newPromise();
        Runnable timeoutTask = new Runnable(){

            @Override
            public void run() {
                // tryFailure: the promise may already have been completed or cancelled, in which
                // case setFailure would throw IllegalStateException on the event loop.
                promise.tryFailure((Throwable)new TimeoutException("Timed out waiting for RPC server connection."));
            }
        };
        final ScheduledFuture timeoutFuture = eloop.schedule(timeoutTask, config.getTimeAsMs(RSCConf.Entry.RPC_CLIENT_HANDSHAKE_TIMEOUT), TimeUnit.MILLISECONDS);
        cf.addListener((GenericFutureListener)new ChannelFutureListener(){

            @Override
            public void operationComplete(ChannelFuture cf) throws Exception {
                if (cf.isSuccess()) {
                    try {
                        SaslClientHandler saslHandler = new SaslClientHandler(config, clientId, (Promise<Rpc>)promise, timeoutFuture, secret, dispatcher);
                        Rpc rpc = Rpc.createRpc(config, saslHandler, (SocketChannel)cf.channel(), (EventExecutorGroup)eloop);
                        saslHandler.rpc = rpc;
                        saslHandler.sendHello(cf.channel());
                    }
                    catch (Exception e) {
                        // Surface the failure to the caller right away instead of letting it wait
                        // for the handshake timeout, and do not leak the connected channel.
                        timeoutFuture.cancel(false);
                        promise.tryFailure((Throwable)e);
                        cf.channel().close();
                    }
                } else {
                    // The connection never came up, so the handshake timeout is pointless. The
                    // promise may already be cancelled (this path also runs after cf.cancel()).
                    timeoutFuture.cancel(false);
                    promise.tryFailure(cf.cause());
                }
            }
        });
        promise.addListener((GenericFutureListener)new GenericFutureListener<Promise<Rpc>>(){

            @Override
            public void operationComplete(Promise<Rpc> p) {
                if (p.isCancelled()) {
                    timeoutFuture.cancel(false);
                    cf.cancel(true);
                }
            }
        });
        return promise;
    }

    /**
     * Builds the server-side Rpc for an accepted channel; the pipeline setup is the same
     * one used for clients, driven by the supplied SASL handler.
     */
    static Rpc createServer(SaslHandler saslHandler, RSCConf config, SocketChannel channel, EventExecutorGroup egroup) throws IOException {
        Rpc serverRpc = createRpc(config, saslHandler, channel, egroup);
        return serverRpc;
    }

    /**
     * Installs the common handlers on {@code client}'s pipeline and wraps it in an Rpc.
     *
     * <p>An optional "logger" handler is added first when the configured channel log level
     * (TRACE unless {@code RPC_CHANNEL_LOG_LEVEL} names a valid level) is enabled on this
     * class's logger. It is followed by the Kryo "codec" and the "sasl" handler.
     */
    private static Rpc createRpc(RSCConf config, SaslHandler saslHandler, SocketChannel client, EventExecutorGroup egroup) throws IOException {
        LogLevel channelLevel = LogLevel.TRACE;
        String configuredLevel = config.get(RSCConf.Entry.RPC_CHANNEL_LOG_LEVEL);
        if (configuredLevel != null) {
            try {
                channelLevel = LogLevel.valueOf((String)configuredLevel);
            }
            catch (Exception e) {
                LOG.warn("Invalid log level {}, reverting to default.", (Object)configuredLevel);
            }
        }

        // Only pay for the logging handler when it would actually emit something.
        boolean attachLogger;
        if (channelLevel == LogLevel.DEBUG) {
            attachLogger = LOG.isDebugEnabled();
        } else if (channelLevel == LogLevel.ERROR) {
            attachLogger = LOG.isErrorEnabled();
        } else if (channelLevel == LogLevel.INFO) {
            attachLogger = LOG.isInfoEnabled();
        } else if (channelLevel == LogLevel.TRACE) {
            attachLogger = LOG.isTraceEnabled();
        } else if (channelLevel == LogLevel.WARN) {
            attachLogger = LOG.isWarnEnabled();
        } else {
            attachLogger = false;
        }
        if (attachLogger) {
            client.pipeline().addLast("logger", (ChannelHandler)new LoggingHandler(Rpc.class, channelLevel));
        }

        int maxMessageSize = config.getInt(RSCConf.Entry.RPC_MAX_MESSAGE_SIZE);
        KryoMessageCodec codec = new KryoMessageCodec(maxMessageSize, MessageHeader.class, NullMessage.class, SaslMessage.class);
        saslHandler.setKryoMessageCodec(codec);
        client.pipeline().addLast("codec", (ChannelHandler)codec);
        client.pipeline().addLast("sasl", (ChannelHandler)saslHandler);
        return new Rpc(config, (Channel)client, egroup);
    }

    /**
     * Creates an Rpc backed by an in-memory {@link EmbeddedChannel} (logging handler, Kryo
     * codec, then the dispatcher) with the dispatcher already registered.
     */
    static Rpc createEmbedded(RpcDispatcher dispatcher) {
        ChannelHandler[] pipelineHandlers = new ChannelHandler[]{new LoggingHandler(Rpc.class), new KryoMessageCodec(0, MessageHeader.class, NullMessage.class), dispatcher};
        EmbeddedChannel embedded = new EmbeddedChannel(pipelineHandlers);
        Rpc embeddedRpc = new Rpc(new RSCConf(null), (Channel)embedded, (EventExecutorGroup)ImmediateEventExecutor.INSTANCE);
        embeddedRpc.dispatcher = dispatcher;
        dispatcher.registerRpc((Channel)embedded, embeddedRpc);
        return embeddedRpc;
    }

    /**
     * Wraps {@code channel} and appends a "monitor" handler that closes this Rpc as soon as
     * the channel becomes inactive.
     *
     * @throws IllegalArgumentException presumably, via {@code Utils.checkArgument}, when
     *         {@code channel} or {@code egroup} is null - confirm in Utils
     */
    private Rpc(RSCConf config, Channel channel, EventExecutorGroup egroup) {
        Utils.checkArgument(channel != null);
        Utils.checkArgument(egroup != null);
        this.rpcClosed = new AtomicBoolean();
        this.rpcId = new AtomicLong();
        this.dispatcher = null;
        this.config = config;
        this.egroup = egroup;
        this.channel = channel;
        ChannelInboundHandlerAdapter inactivityMonitor = new ChannelInboundHandlerAdapter(){

            public void channelInactive(ChannelHandlerContext ctx) throws Exception {
                Rpc.this.close();
                super.channelInactive(ctx);
            }
        };
        channel.pipeline().addLast("monitor", (ChannelHandler)inactivityMonitor);
    }

    /**
     * Returns the simple class name of this instance; used as the "[name]" prefix in this
     * class's log messages.
     */
    protected String name() {
        return this.getClass().getSimpleName();
    }

    /**
     * Processes one decoded object from the channel.
     *
     * <p>Messages arrive as a header followed by a payload. When no header is pending, the
     * object must be a {@link MessageHeader} and is merely remembered. Otherwise the object
     * is the payload for the pending header and is routed by the header's type. The pending
     * header is always cleared once the payload has been handled, even on failure.
     *
     * @param ctx         channel context the message arrived on
     * @param msg         decoded header or payload
     * @param handleClass class searched for a matching {@code handle} method on CALLs
     * @param obj         instance the {@code handle} method is invoked on
     * @throws IllegalArgumentException if a header was expected but not received, or the
     *         header carries an unknown message type
     */
    public void handleMsg(ChannelHandlerContext ctx, Object msg, Class<?> handleClass, Object obj) throws Exception {
        if (this.lastHeader != null) {
            String payloadType = msg != null ? msg.getClass().getName() : null;
            LOG.debug("[{}] Received RPC message: type={} id={} payload={}", new Object[]{this.name(), this.lastHeader.type, this.lastHeader.id, payloadType});
            try {
                switch (this.lastHeader.type) {
                    case CALL: {
                        this.handleCall(ctx, msg, handleClass, obj);
                        break;
                    }
                    case REPLY: {
                        OutstandingRpc replied = this.findRpcCall(this.lastHeader.id);
                        this.handleReply(ctx, msg, replied);
                        break;
                    }
                    case ERROR: {
                        OutstandingRpc failed = this.findRpcCall(this.lastHeader.id);
                        this.handleError(ctx, msg, failed);
                        break;
                    }
                    default: {
                        throw new IllegalArgumentException("Unknown RPC message type: " + this.lastHeader.type);
                    }
                }
            }
            finally {
                this.lastHeader = null;
            }
            return;
        }

        // No header pending: this object has to be one.
        if (msg instanceof MessageHeader) {
            this.lastHeader = (MessageHeader)msg;
            return;
        }
        LOG.warn("[{}] Expected RPC header, got {} instead.", (Object)this.name(), msg != null ? msg.getClass().getName() : null);
        throw new IllegalArgumentException();
    }

    /**
     * Dispatches an incoming CALL payload to the matching
     * {@code handle(ChannelHandlerContext, <payload type>)} method and writes the result
     * back as a REPLY (a {@link NullMessage} stands in for a null result). If the handler
     * throws, or no handler exists, an ERROR carrying a stack trace string is written.
     *
     * <p>Resolved methods are cached per payload class. The lookup tries methods declared
     * on {@code handleClass} first and falls back to public (possibly inherited) ones.
     *
     * @param ctx         channel context passed through to the handler
     * @param msg         call payload
     * @param handleClass class searched for the handler method
     * @param obj         instance the handler is invoked on
     * @throws Exception reflective failures other than the handler itself throwing
     */
    private void handleCall(ChannelHandlerContext ctx, Object msg, Class<?> handleClass, Object obj) throws Exception {
        Method handler = this.handlers.get(msg.getClass());
        if (handler == null) {
            try {
                handler = handleClass.getDeclaredMethod("handle", ChannelHandlerContext.class, msg.getClass());
            }
            catch (NoSuchMethodException e) {
                try {
                    handler = handleClass.getMethod("handle", ChannelHandlerContext.class, msg.getClass());
                }
                catch (NoSuchMethodException e2) {
                    LOG.warn(String.format("[%s] Failed to find handler for msg '%s'.", this.name(), msg.getClass().getName()));
                    // Report the lookup failure itself: a reflective NoSuchMethodException has
                    // no cause, so getCause() would hand a null throwable to the formatter.
                    this.writeMessage(MessageType.ERROR, Utils.stackTraceAsString(e2));
                    return;
                }
            }
            handler.setAccessible(true);
            this.handlers.put(msg.getClass(), handler);
        }
        try {
            Object payload = handler.invoke(obj, ctx, msg);
            if (payload == null) {
                payload = new NullMessage();
            }
            this.writeMessage(MessageType.REPLY, payload);
        }
        catch (InvocationTargetException ite) {
            // The interesting failure is whatever the handler threw; fall back to the wrapper
            // in the unusual case where it carries no cause.
            Throwable failure = ite.getCause() != null ? ite.getCause() : ite;
            LOG.debug(String.format("[%s] Error in RPC handler.", this.name()), failure);
            this.writeMessage(MessageType.ERROR, Utils.stackTraceAsString(failure));
        }
    }

    /**
     * Completes the outstanding call with the reply payload; a {@link NullMessage} payload
     * is delivered to the caller as {@code null}.
     *
     * <p>Uses {@code trySuccess} because the caller holds the same future and may already
     * have cancelled it; {@code setSuccess} would then throw inside the channel pipeline.
     */
    private void handleReply(ChannelHandlerContext ctx, Object msg, OutstandingRpc rpc) {
        Object result = msg instanceof NullMessage ? null : msg;
        if (!rpc.future.trySuccess(result)) {
            LOG.debug("[{}] Ignoring reply for RPC {}: future already completed.", (Object)this.name(), (Object)rpc.id);
        }
    }

    /**
     * Fails the outstanding call with the error reported by the peer.
     *
     * <p>A {@code String} payload is wrapped in an {@link RpcException}. Any other payload is
     * a protocol violation: the call fails with {@link IllegalArgumentException} and the
     * channel is closed.
     *
     * <p>Uses {@code tryFailure} because the caller holds the same future and may already
     * have cancelled it; {@code setFailure} would then throw inside the channel pipeline.
     */
    private void handleError(ChannelHandlerContext ctx, Object msg, OutstandingRpc rpc) {
        if (msg instanceof String) {
            LOG.warn("Received error message:{}.", msg);
            rpc.future.tryFailure((Throwable)new RpcException((String)msg));
        } else {
            String error = String.format("Received error with unexpected payload (%s).", msg != null ? msg.getClass().getName() : null);
            LOG.warn(String.format("[%s] %s", this.name(), error));
            rpc.future.tryFailure((Throwable)new IllegalArgumentException(error));
            ctx.close();
        }
    }

    /**
     * Sends a response for the message currently being handled: a header that reuses the
     * pending header's id with the given type, then the payload (flushing both).
     */
    private void writeMessage(MessageType replyType, Object payload) {
        MessageHeader responseHeader = new MessageHeader(this.lastHeader.id, replyType);
        this.channel.write((Object)responseHeader);
        this.channel.writeAndFlush(payload);
    }

    /**
     * Removes and returns the outstanding call registered under {@code id}.
     *
     * @throws IllegalArgumentException if no outstanding call has that id
     */
    private OutstandingRpc findRpcCall(long id) {
        for (Iterator<OutstandingRpc> cursor = this.rpcCalls.iterator(); cursor.hasNext(); ) {
            OutstandingRpc candidate = cursor.next();
            if (candidate.id == id) {
                cursor.remove();
                return candidate;
            }
        }
        throw new IllegalArgumentException(String.format("Received RPC reply for unknown RPC (%d).", id));
    }

    /**
     * Records a call as outstanding so its REPLY/ERROR can later be matched by id.
     * {@code type} is only used for the debug log line.
     */
    private void registerRpcCall(long id, Promise<?> promise, String type) {
        Object[] logArgs = new Object[]{this.name(), id, type};
        LOG.debug("[{}] Registered outstanding rpc {} ({}).", logArgs);
        OutstandingRpc pending = new OutstandingRpc(id, promise);
        this.rpcCalls.add(pending);
    }

    /**
     * Drops the outstanding call with the given id. The lookup performs the removal; its
     * result is intentionally unused.
     *
     * @throws IllegalArgumentException if no outstanding call has that id
     */
    private void discardRpcCall(long id) {
        String who = this.name();
        LOG.debug("[{}] Discarding failed RPC {}.", (Object)who, (Object)id);
        this.findRpcCall(id);
    }

    /**
     * Reacts to an exception raised in the channel pipeline: logs it (at debug when debug
     * logging is on, otherwise at info), reports it to the peer as an ERROR for the message
     * currently being handled (if any), and closes the channel.
     */
    public void handleChannelException(ChannelHandlerContext ctx, Throwable cause) {
        if (!LOG.isDebugEnabled()) {
            LOG.info(String.format("[%s] Caught exception in channel pipeline.", this.name()), cause);
        } else {
            LOG.debug(String.format("[%s] Caught exception in channel pipeline.", this.name()), cause);
        }
        if (this.lastHeader != null) {
            MessageHeader errorHeader = new MessageHeader(this.lastHeader.id, MessageType.ERROR);
            this.channel.write((Object)errorHeader);
            String trace = Utils.stackTraceAsString(cause);
            this.channel.writeAndFlush((Object)trace);
            this.lastHeader = null;
        }
        ctx.close();
    }

    /**
     * Called when the channel goes inactive: cancels every call still awaiting a response,
     * or just logs the event when none are pending.
     */
    public void handleChannelInactive() {
        if (this.rpcCalls.size() <= 0) {
            LOG.debug("Channel {} became inactive.", (Object)this.channel);
            return;
        }
        LOG.warn("[{}] Closing RPC channel with {} outstanding RPCs.", (Object)this.name(), (Object)this.rpcCalls.size());
        Iterator<OutstandingRpc> pending = this.rpcCalls.iterator();
        while (pending.hasNext()) {
            pending.next().future.cancel(true);
        }
    }

    /**
     * Sends {@code msg} as an RPC call whose reply carries no value of interest.
     *
     * @param msg call payload
     * @return future for the call, typed as {@code Void}
     */
    public Future<Void> call(Object msg) {
        return this.call(msg, Void.class);
    }

    /**
     * Sends {@code msg} to the peer as an RPC call.
     *
     * <p>The call is registered as outstanding under a fresh id, then a CALL header and the
     * payload are written from the channel's event loop. The returned future is completed
     * when the matching REPLY or ERROR arrives. If either write fails, the future is failed,
     * the call is discarded and this Rpc is closed.
     *
     * @param msg     call payload
     * @param retType expected reply type; only used to fix {@code T} at the call site
     * @return future completed with the reply
     */
    public <T> Future<T> call(final Object msg, Class<T> retType) {
        Utils.checkArgument(msg != null);
        Utils.checkState(this.channel.isOpen(), "RPC channel is closed.", new Object[0]);
        final long id = this.rpcId.getAndIncrement();
        boolean registered = false;
        try {
            final Promise<T> promise = this.egroup.next().newPromise();
            final ChannelFutureListener listener = new ChannelFutureListener(){

                @Override
                public void operationComplete(ChannelFuture cf) {
                    if (cf.isSuccess()) {
                        return;
                    }
                    // tryFailure both detects "already completed" (second write failing, or
                    // the caller cancelling) and fails the promise atomically.
                    if (!promise.tryFailure(cf.cause())) {
                        return;
                    }
                    LOG.warn("Failed to send RPC, closing connection.", cf.cause());
                    try {
                        Rpc.this.discardRpcCall(id);
                    }
                    finally {
                        // Close even if the call was no longer registered.
                        Rpc.this.close();
                    }
                }
            };
            this.registerRpcCall(id, promise, msg.getClass().getName());
            registered = true;
            this.channel.eventLoop().submit(new Runnable(){

                @Override
                public void run() {
                    Rpc.this.channel.write((Object)new MessageHeader(id, MessageType.CALL)).addListener((GenericFutureListener)listener);
                    Rpc.this.channel.writeAndFlush(msg).addListener((GenericFutureListener)listener);
                }
            });
            return promise;
        }
        catch (Exception e) {
            if (registered) {
                // The write was never scheduled, so no response can ever arrive: do not leave
                // the call registered as outstanding.
                try {
                    this.discardRpcCall(id);
                }
                catch (IllegalArgumentException ignored) {
                    // Already removed; the original failure is the one worth reporting.
                }
            }
            throw Utils.propagate(e);
        }
    }

    /** Returns the channel this Rpc reads from and writes to. */
    public Channel getChannel() {
        return this.channel;
    }

    /** Unregisters this Rpc's channel from the dispatcher; a no-op when none is set. */
    public void unRegisterRpc() {
        if (this.dispatcher == null) {
            return;
        }
        this.dispatcher.unregisterRpc(this.channel);
    }

    /**
     * Installs the dispatcher: stores it, appends it to the channel pipeline under the name
     * "dispatcher", and registers this Rpc with it. May only be done once per Rpc.
     */
    void setDispatcher(RpcDispatcher dispatcher) {
        Utils.checkNotNull((Object)dispatcher);
        boolean noDispatcherYet = this.dispatcher == null;
        Utils.checkState(noDispatcherYet, "Dispatcher already set.", new Object[0]);
        this.dispatcher = dispatcher;
        ChannelHandler asHandler = (ChannelHandler)dispatcher;
        this.channel.pipeline().addLast("dispatcher", asHandler);
        dispatcher.registerRpc(this.channel, this);
    }

    /**
     * Closes the underlying channel and waits for the close to finish. Only the first call
     * does anything; later calls return immediately.
     *
     * <p>If the wait is interrupted, the thread's interrupt status is restored so callers
     * further up the stack can still observe the interruption.
     */
    @Override
    public void close() {
        if (!this.rpcClosed.compareAndSet(false, true)) {
            return;
        }
        try {
            this.channel.close().sync();
        }
        catch (InterruptedException ie) {
            // Thread.interrupted() would merely (re-)clear the flag and swallow the interrupt.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Client side of the SASL handshake. Supplies credentials to the {@link SaslClient} via
     * the {@link CallbackHandler} interface and completes the connection promise handed out
     * by {@code createClient} once negotiation finishes or fails.
     */
    private static class SaslClientHandler
    extends SaslHandler
    implements CallbackHandler {
        private final SaslClient client;
        private final String clientId;
        private final String secret;
        private final RpcDispatcher dispatcher;
        // Both are released (set to null) once the handshake has completed.
        private Promise<Rpc> promise;
        private ScheduledFuture<?> timeout;
        // Assigned by createClient right after the Rpc is built, before the hello is sent.
        private Rpc rpc;

        SaslClientHandler(RSCConf config, String clientId, Promise<Rpc> promise, ScheduledFuture<?> timeout, String secret, RpcDispatcher dispatcher) throws IOException {
            super(config);
            this.clientId = clientId;
            this.promise = promise;
            this.timeout = timeout;
            this.secret = secret;
            this.dispatcher = dispatcher;
            this.client = Sasl.createSaslClient(new String[]{config.get(RSCConf.Entry.SASL_MECHANISMS)}, null, "rsc", "rsc", config.getSaslOptions(), this);
        }

        @Override
        protected boolean isComplete() {
            return this.client.isComplete();
        }

        @Override
        protected String getNegotiatedProperty(String name) {
            return (String)this.client.getNegotiatedProperty(name);
        }

        /** Feeds the server challenge to the SASL client; returns null when there is no response to send. */
        @Override
        protected SaslMessage update(SaslMessage challenge) throws IOException {
            byte[] response = this.client.evaluateChallenge(challenge.payload);
            return response != null ? new SaslMessage(response) : null;
        }

        @Override
        public byte[] wrap(byte[] data, int offset, int len) throws IOException {
            return this.client.wrap(data, offset, len);
        }

        @Override
        public byte[] unwrap(byte[] data, int offset, int len) throws IOException {
            return this.client.unwrap(data, offset, len);
        }

        /** Releases the SASL client, first failing the connection promise if negotiation never finished. */
        @Override
        public void dispose() throws IOException {
            if (!this.client.isComplete()) {
                this.onError(new SaslException("Client closed before SASL negotiation finished."));
            }
            this.client.dispose();
        }

        /**
         * Handshake finished: stop the timeout, install the dispatcher and hand the Rpc to
         * the caller waiting on the promise.
         */
        @Override
        protected void onComplete() throws Exception {
            this.timeout.cancel(true);
            this.rpc.setDispatcher(this.dispatcher);
            // The timeout task (or a cancellation) may have completed the promise just before
            // we got here; setSuccess would throw in that case. Nobody will ever receive this
            // Rpc then, so close it rather than leave the connection open.
            if (!this.promise.trySuccess(this.rpc)) {
                this.rpc.close();
            }
            this.timeout = null;
            this.promise = null;
        }

        /** Handshake failed: stop the timeout and fail the promise unless it is already complete. */
        @Override
        protected void onError(Throwable error) {
            // Both fields are nulled by onComplete(); tolerate being called after that.
            ScheduledFuture<?> pendingTimeout = this.timeout;
            if (pendingTimeout != null) {
                pendingTimeout.cancel(true);
            }
            Promise<Rpc> pendingPromise = this.promise;
            if (pendingPromise != null) {
                // tryFailure closes the window between an isDone() check and setFailure().
                pendingPromise.tryFailure(error);
            }
        }

        /**
         * Answers the SASL client's callbacks: the client id as name, the shared secret as
         * password, and the default text for the realm. Other callbacks are ignored.
         */
        @Override
        public void handle(Callback[] callbacks) {
            for (Callback cb : callbacks) {
                if (cb instanceof NameCallback) {
                    ((NameCallback)cb).setName(this.clientId);
                    continue;
                }
                if (cb instanceof PasswordCallback) {
                    ((PasswordCallback)cb).setPassword(this.secret.toCharArray());
                    continue;
                }
                if (!(cb instanceof RealmCallback)) continue;
                RealmCallback rb = (RealmCallback)cb;
                rb.setText(rb.getDefaultText());
            }
        }

        /**
         * Sends the first handshake message: the client id plus the mechanism's initial
         * response (empty when the mechanism has none), waiting for the write to complete.
         */
        void sendHello(Channel c) throws Exception {
            byte[] hello = this.client.hasInitialResponse() ? this.client.evaluateChallenge(new byte[0]) : new byte[]{};
            c.writeAndFlush((Object)new SaslMessage(this.clientId, hello)).sync();
        }
    }

    /**
     * Handshake message exchanged during SASL negotiation. Registered with the Kryo codec
     * in {@code createRpc}.
     */
    static class SaslMessage {
        // Set only on the client's hello; the other constructors leave it null.
        final String clientId;
        // SASL challenge / response bytes.
        final byte[] payload;

        // NOTE(review): presumably required by Kryo to instantiate the class on decode - confirm.
        SaslMessage() {
            this(null, null);
        }

        SaslMessage(byte[] payload) {
            this(null, payload);
        }

        SaslMessage(String clientId, byte[] payload) {
            this.clientId = clientId;
            this.payload = payload;
        }
    }

    /**
     * Placeholder payload sent as the REPLY when a handler returns {@code null}; the
     * receiving side turns it back into {@code null} (see handleCall / handleReply).
     */
    static class NullMessage {
        NullMessage() {
        }
    }

    /**
     * Header written ahead of every RPC payload. Registered with the Kryo codec in
     * {@code createRpc}.
     */
    static class MessageHeader {
        // Id of the call; responses reuse the id of the call they answer.
        final long id;
        // Whether the payload that follows is a CALL, a REPLY or an ERROR.
        final MessageType type;

        // NOTE(review): presumably required by Kryo to instantiate the class on decode - confirm.
        MessageHeader() {
            this(-1L, null);
        }

        MessageHeader(long id, MessageType type) {
            this.id = id;
            this.type = type;
        }
    }

    /** Kind of payload that follows a {@link MessageHeader}. */
    static enum MessageType {
        // Request to be dispatched to a handler method.
        CALL,
        // Successful result of a CALL.
        REPLY,
        // Failure result of a CALL, or a pipeline error reported to the peer.
        ERROR;

    }

    /** A call that has been sent (or is about to be) and is still waiting for its response. */
    private static class OutstandingRpc {
        // Id carried in the CALL header; used to match the response.
        final long id;
        // Completed with the reply payload, failed on ERROR, cancelled when the channel dies.
        final Promise<Object> future;

        // The promise is only ever completed with whatever the peer sent back, so viewing
        // it as Promise<Object> is safe; the cast is required for the assignment to compile.
        @SuppressWarnings("unchecked")
        OutstandingRpc(long id, Promise<?> future) {
            this.id = id;
            this.future = (Promise<Object>)future;
        }
    }
}

