/* * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ package org.opensearch.performanceanalyzer.net; import com.google.common.annotations.VisibleForTesting; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.NettyChannelBuilder; import io.netty.handler.ssl.SslContextBuilder; import java.io.File; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import javax.net.ssl.SSLException; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.performanceanalyzer.CertificateUtils; import org.opensearch.performanceanalyzer.commons.collectors.StatsCollector; import org.opensearch.performanceanalyzer.commons.stats.metrics.StatExceptionCode; import org.opensearch.performanceanalyzer.grpc.InterNodeRpcServiceGrpc; import org.opensearch.performanceanalyzer.grpc.InterNodeRpcServiceGrpc.InterNodeRpcServiceStub; import org.opensearch.performanceanalyzer.rca.framework.util.InstanceDetails; /** * Class that manages the channel to other hosts in the cluster. It also performs staleness checks, * and initiates a new connection if it deems a channel to have gone stale. * *

It also listens to cluster state changes and manages handling connections to the changed * hosts. */ public class GRPCConnectionManager { private static final Logger LOG = LogManager.getLogger(GRPCConnectionManager.class); private static final int MAX_RETRY_ATTEMPTS = 2; private final int port; // TLS certificate, private key, and trusted root CA files private File certFile; private File pkeyFile; private File trustedCasFile; /** Map of remote hostId to a Netty channel to that host. */ private ConcurrentMap> perHostChannelMap = new ConcurrentHashMap<>(); /** * Map of remote hostId to a grpc client object of that host. The client objects are created * over the channels for those hosts and are used to call RPC methods on the hosts. */ private ConcurrentMap> perHostClientStubMap = new ConcurrentHashMap<>(); /** Flag that controls if we need to use a secure or an insecure channel. */ private final boolean shouldUseHttps; public GRPCConnectionManager(final boolean shouldUseHttps) { this.shouldUseHttps = shouldUseHttps; this.port = 0; if (shouldUseHttps) { this.certFile = CertificateUtils.getClientCertificateFile(); this.pkeyFile = CertificateUtils.getClientPrivateKeyFile(); this.trustedCasFile = CertificateUtils.getClientTrustedCasFile(); } } /** * Constructor that allows you to specify which port a client should connect to * * @param shouldUseHttps Whether to enable TLS * @param port The port number that client stubs should attempt to connect to */ public GRPCConnectionManager(final boolean shouldUseHttps, int port) { this.shouldUseHttps = shouldUseHttps; this.port = port; if (shouldUseHttps) { this.certFile = CertificateUtils.getClientCertificateFile(); this.pkeyFile = CertificateUtils.getClientPrivateKeyFile(); this.trustedCasFile = CertificateUtils.getClientTrustedCasFile(); } } @VisibleForTesting public ConcurrentMap> getPerHostChannelMap() { return perHostChannelMap; } @VisibleForTesting public ConcurrentMap> getPerHostClientStubMap() { return perHostClientStubMap; } /** * Gets the client stub(on which the rpcs can be initiated) for a host. * * @param remoteHost The host to which we want to make an RPC to. * @return The stub object. */ public InterNodeRpcServiceStub getClientStubForHost(final InstanceDetails remoteHost) { final AtomicReference stubAtomicReference = perHostClientStubMap.get(remoteHost.getInstanceId()); if (stubAtomicReference != null) { return stubAtomicReference.get(); } return addOrUpdateClientStubForHost(remoteHost); } /** * Builds or updates a stub object for host. Callers: The subscription send thread, the flow * unit send thread. * * @param remoteHost The host to which an RPC needs to be made. * @return The stub object. */ private synchronized InterNodeRpcServiceStub addOrUpdateClientStubForHost( final InstanceDetails remoteHost) { final InterNodeRpcServiceStub stub = buildStubForHost(remoteHost); perHostClientStubMap.computeIfAbsent( remoteHost.getInstanceId(), s -> new AtomicReference<>()); perHostClientStubMap.get(remoteHost.getInstanceId()).set(stub); return stub; } public void shutdown() { removeAllStubs(); terminateAllConnections(); } private ManagedChannel getChannelForHost(final InstanceDetails remoteHost) { final AtomicReference managedChannelAtomicReference = perHostChannelMap.get(remoteHost.getInstanceId()); if (managedChannelAtomicReference != null) { return managedChannelAtomicReference.get(); } return addOrUpdateChannelForHost(remoteHost); } /** * Builds or updates a channel object to be used by a client stub. Callers: Send flow unit * thread, send subscription thread. * * @param remoteHost The host to which we want to establish a channel to. * @return a Managed channel object. */ private synchronized ManagedChannel addOrUpdateChannelForHost( final InstanceDetails remoteHost) { final ManagedChannel channel = buildChannelForHost(remoteHost); perHostChannelMap.computeIfAbsent(remoteHost.getInstanceId(), s -> new AtomicReference<>()); perHostChannelMap.get(remoteHost.getInstanceId()).set(channel); return channel; } private ManagedChannel buildChannelForHost(final InstanceDetails remoteHost) { return shouldUseHttps ? buildSecureChannel(remoteHost) : buildInsecureChannel(remoteHost); } private int getPortFromHost(final InstanceDetails remoteHost) { int port = this.port != 0 ? this.port : remoteHost.getGrpcPort(); if (port == -1) { throw new IllegalArgumentException("Invalid port for grpc: " + port); } return port; } private ManagedChannel buildInsecureChannel(final InstanceDetails remoteHost) { return ManagedChannelBuilder.forAddress( remoteHost.getInstanceIp().toString(), getPortFromHost(remoteHost)) .usePlaintext() .enableRetry() .maxRetryAttempts(MAX_RETRY_ATTEMPTS) .build(); } private ManagedChannel buildSecureChannel(final InstanceDetails remoteHost) { try { SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient().keyManager(certFile, pkeyFile); if (trustedCasFile != null) { sslContextBuilder.trustManager(trustedCasFile); } return NettyChannelBuilder.forAddress( remoteHost.getInstanceIp().toString(), getPortFromHost(remoteHost)) .sslContext(sslContextBuilder.build()) .enableRetry() .maxRetryAttempts(MAX_RETRY_ATTEMPTS) .build(); } catch (SSLException e) { LOG.error("Unable to build an SSL gRPC client.", e); // Wrap the SSL Exception in a generic RTE and re-throw. throw new RuntimeException(e); } } private InterNodeRpcServiceStub buildStubForHost(final InstanceDetails remoteHost) { return InterNodeRpcServiceGrpc.newStub(getChannelForHost(remoteHost)); } private void removeAllStubs() { for (Map.Entry> entry : perHostClientStubMap.entrySet()) { LOG.debug("Removing client stub for host: {}", entry.getKey()); perHostClientStubMap.remove(entry.getKey()); } } private void terminateAllConnections() { for (Map.Entry> entry : perHostChannelMap.entrySet()) { LOG.debug("shutting down connection to host: {}", entry.getKey()); ManagedChannel channel = entry.getValue().get(); channel.shutdownNow(); try { if (!channel.awaitTermination(1, TimeUnit.MINUTES)) { StatsCollector.instance() .logException(StatExceptionCode.GRPC_CHANNEL_CLOSURE_ERROR); LOG.warn("Unable to close channel gracefully for host: {}", entry.getKey()); } } catch (InterruptedException e) { LOG.warn("Channel interrupted while shutting down", e); channel.shutdownNow(); Thread.currentThread().interrupt(); } perHostChannelMap.remove(entry.getKey()); } } /** * Removes the stub and the channel object for the host. * * @param remoteHost the host to which we want to terminate connection from. */ public void terminateConnection(InstanceDetails.Id remoteHost) { perHostClientStubMap.remove(remoteHost); perHostChannelMap.remove(remoteHost); } }