/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.performanceanalyzer.net;
import com.google.common.annotations.VisibleForTesting;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.File;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLException;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.performanceanalyzer.CertificateUtils;
import org.opensearch.performanceanalyzer.commons.collectors.StatsCollector;
import org.opensearch.performanceanalyzer.commons.stats.metrics.StatExceptionCode;
import org.opensearch.performanceanalyzer.grpc.InterNodeRpcServiceGrpc;
import org.opensearch.performanceanalyzer.grpc.InterNodeRpcServiceGrpc.InterNodeRpcServiceStub;
import org.opensearch.performanceanalyzer.rca.framework.util.InstanceDetails;
/**
 * Class that manages the channel to other hosts in the cluster. It also performs staleness checks,
 * and initiates a new connection if it deems a channel to have gone stale.
 *
 * <p>It also listens to cluster state changes and manages handling connections to the changed
 * hosts.
 */
public class GRPCConnectionManager {
private static final Logger LOG = LogManager.getLogger(GRPCConnectionManager.class);
private static final int MAX_RETRY_ATTEMPTS = 2;
private final int port;
// TLS certificate, private key, and trusted root CA files
private File certFile;
private File pkeyFile;
private File trustedCasFile;
/** Map of remote hostId to a Netty channel to that host. */
private ConcurrentMap> perHostChannelMap =
new ConcurrentHashMap<>();
/**
* Map of remote hostId to a grpc client object of that host. The client objects are created
* over the channels for those hosts and are used to call RPC methods on the hosts.
*/
private ConcurrentMap>
perHostClientStubMap = new ConcurrentHashMap<>();
/** Flag that controls if we need to use a secure or an insecure channel. */
private final boolean shouldUseHttps;
public GRPCConnectionManager(final boolean shouldUseHttps) {
this.shouldUseHttps = shouldUseHttps;
this.port = 0;
if (shouldUseHttps) {
this.certFile = CertificateUtils.getClientCertificateFile();
this.pkeyFile = CertificateUtils.getClientPrivateKeyFile();
this.trustedCasFile = CertificateUtils.getClientTrustedCasFile();
}
}
/**
* Constructor that allows you to specify which port a client should connect to
*
* @param shouldUseHttps Whether to enable TLS
* @param port The port number that client stubs should attempt to connect to
*/
public GRPCConnectionManager(final boolean shouldUseHttps, int port) {
this.shouldUseHttps = shouldUseHttps;
this.port = port;
if (shouldUseHttps) {
this.certFile = CertificateUtils.getClientCertificateFile();
this.pkeyFile = CertificateUtils.getClientPrivateKeyFile();
this.trustedCasFile = CertificateUtils.getClientTrustedCasFile();
}
}
@VisibleForTesting
public ConcurrentMap>
getPerHostChannelMap() {
return perHostChannelMap;
}
@VisibleForTesting
public ConcurrentMap>
getPerHostClientStubMap() {
return perHostClientStubMap;
}
/**
* Gets the client stub(on which the rpcs can be initiated) for a host.
*
* @param remoteHost The host to which we want to make an RPC to.
* @return The stub object.
*/
public InterNodeRpcServiceStub getClientStubForHost(final InstanceDetails remoteHost) {
final AtomicReference stubAtomicReference =
perHostClientStubMap.get(remoteHost.getInstanceId());
if (stubAtomicReference != null) {
return stubAtomicReference.get();
}
return addOrUpdateClientStubForHost(remoteHost);
}
/**
* Builds or updates a stub object for host. Callers: The subscription send thread, the flow
* unit send thread.
*
* @param remoteHost The host to which an RPC needs to be made.
* @return The stub object.
*/
private synchronized InterNodeRpcServiceStub addOrUpdateClientStubForHost(
final InstanceDetails remoteHost) {
final InterNodeRpcServiceStub stub = buildStubForHost(remoteHost);
perHostClientStubMap.computeIfAbsent(
remoteHost.getInstanceId(), s -> new AtomicReference<>());
perHostClientStubMap.get(remoteHost.getInstanceId()).set(stub);
return stub;
}
public void shutdown() {
removeAllStubs();
terminateAllConnections();
}
private ManagedChannel getChannelForHost(final InstanceDetails remoteHost) {
final AtomicReference managedChannelAtomicReference =
perHostChannelMap.get(remoteHost.getInstanceId());
if (managedChannelAtomicReference != null) {
return managedChannelAtomicReference.get();
}
return addOrUpdateChannelForHost(remoteHost);
}
/**
* Builds or updates a channel object to be used by a client stub. Callers: Send flow unit
* thread, send subscription thread.
*
* @param remoteHost The host to which we want to establish a channel to.
* @return a Managed channel object.
*/
private synchronized ManagedChannel addOrUpdateChannelForHost(
final InstanceDetails remoteHost) {
final ManagedChannel channel = buildChannelForHost(remoteHost);
perHostChannelMap.computeIfAbsent(remoteHost.getInstanceId(), s -> new AtomicReference<>());
perHostChannelMap.get(remoteHost.getInstanceId()).set(channel);
return channel;
}
private ManagedChannel buildChannelForHost(final InstanceDetails remoteHost) {
return shouldUseHttps ? buildSecureChannel(remoteHost) : buildInsecureChannel(remoteHost);
}
private int getPortFromHost(final InstanceDetails remoteHost) {
int port = this.port != 0 ? this.port : remoteHost.getGrpcPort();
if (port == -1) {
throw new IllegalArgumentException("Invalid port for grpc: " + port);
}
return port;
}
private ManagedChannel buildInsecureChannel(final InstanceDetails remoteHost) {
return ManagedChannelBuilder.forAddress(
remoteHost.getInstanceIp().toString(), getPortFromHost(remoteHost))
.usePlaintext()
.enableRetry()
.maxRetryAttempts(MAX_RETRY_ATTEMPTS)
.build();
}
private ManagedChannel buildSecureChannel(final InstanceDetails remoteHost) {
try {
SslContextBuilder sslContextBuilder =
GrpcSslContexts.forClient().keyManager(certFile, pkeyFile);
if (trustedCasFile != null) {
sslContextBuilder.trustManager(trustedCasFile);
}
return NettyChannelBuilder.forAddress(
remoteHost.getInstanceIp().toString(), getPortFromHost(remoteHost))
.sslContext(sslContextBuilder.build())
.enableRetry()
.maxRetryAttempts(MAX_RETRY_ATTEMPTS)
.build();
} catch (SSLException e) {
LOG.error("Unable to build an SSL gRPC client.", e);
// Wrap the SSL Exception in a generic RTE and re-throw.
throw new RuntimeException(e);
}
}
private InterNodeRpcServiceStub buildStubForHost(final InstanceDetails remoteHost) {
return InterNodeRpcServiceGrpc.newStub(getChannelForHost(remoteHost));
}
private void removeAllStubs() {
for (Map.Entry> entry :
perHostClientStubMap.entrySet()) {
LOG.debug("Removing client stub for host: {}", entry.getKey());
perHostClientStubMap.remove(entry.getKey());
}
}
private void terminateAllConnections() {
for (Map.Entry> entry :
perHostChannelMap.entrySet()) {
LOG.debug("shutting down connection to host: {}", entry.getKey());
ManagedChannel channel = entry.getValue().get();
channel.shutdownNow();
try {
if (!channel.awaitTermination(1, TimeUnit.MINUTES)) {
StatsCollector.instance()
.logException(StatExceptionCode.GRPC_CHANNEL_CLOSURE_ERROR);
LOG.warn("Unable to close channel gracefully for host: {}", entry.getKey());
}
} catch (InterruptedException e) {
LOG.warn("Channel interrupted while shutting down", e);
channel.shutdownNow();
Thread.currentThread().interrupt();
}
perHostChannelMap.remove(entry.getKey());
}
}
/**
* Removes the stub and the channel object for the host.
*
* @param remoteHost the host to which we want to terminate connection from.
*/
public void terminateConnection(InstanceDetails.Id remoteHost) {
perHostClientStubMap.remove(remoteHost);
perHostChannelMap.remove(remoteHost);
}
}