/* * Copyright <2022> Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. * A copy of the License is located at * * http://www.apache.org/licenses/LICENSE-2.0 * * or in the "license" file accompanying this file. This file is distributed * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either * express or implied. See the License for the specific language governing * permissions and limitations under the License. * */ package software.amazon.documentdb.jdbc.sshtunnel; import com.google.common.annotations.VisibleForTesting; import com.google.common.hash.Hashing; import com.jcraft.jsch.HostKey; import com.jcraft.jsch.HostKeyRepository; import com.jcraft.jsch.JSch; import com.jcraft.jsch.JSchException; import com.jcraft.jsch.Session; import lombok.AllArgsConstructor; import lombok.Getter; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.checkerframework.checker.nullness.qual.NonNull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.documentdb.jdbc.DocumentDbConnectionProperties; import software.amazon.documentdb.jdbc.common.utilities.SqlError; import software.amazon.documentdb.jdbc.common.utilities.SqlState; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.sql.SQLException; import java.util.Arrays; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import static software.amazon.documentdb.jdbc.DocumentDbConnectionProperties.getPath; import static software.amazon.documentdb.jdbc.DocumentDbConnectionProperties.isNullOrWhitespace; /** * Provides a single-instance SSH Tunnel server. *

* Use the {@link #builder(String, String, String, String)} method to instantiate * a new {@link DocumentDbSshTunnelServerBuilder} object. Set the properties as needed, * then call the build() method. */ public final class DocumentDbSshTunnelServer implements AutoCloseable { public static final String SSH_KNOWN_HOSTS_FILE = "~/.ssh/known_hosts"; public static final String STRICT_HOST_KEY_CHECKING = "StrictHostKeyChecking"; public static final String HASH_KNOWN_HOSTS = "HashKnownHosts"; public static final String SERVER_HOST_KEY = "server_host_key"; public static final String YES = "yes"; public static final String NO = "no"; public static final String LOCALHOST = "localhost"; public static final int DEFAULT_DOCUMENTDB_PORT = 27017; public static final int DEFAULT_SSH_PORT = 22; private static final Logger LOGGER = LoggerFactory.getLogger(DocumentDbSshTunnelServer.class); public static final int DEFAULT_CLOSE_DELAY_MS = 30000; private final Object mutex = new Object(); private final AtomicLong clientCount = new AtomicLong(0); private final String sshUser; private final String sshHostname; private final String sshPrivateKeyFile; private final String sshPrivateKeyPassphrase; private final boolean sshStrictHostKeyChecking; private final String sshKnownHostsFile; private final String remoteHostname; private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1); private DocumentDbSshTunnelServer.SshPortForwardingSession session = null; private ScheduledFuture scheduledFuture = null; private long closeDelayMS = DEFAULT_CLOSE_DELAY_MS; private DocumentDbSshTunnelServer(final DocumentDbSshTunnelServerBuilder builder) { this.sshUser = builder.sshUser; this.sshHostname = builder.sshHostname; this.sshPrivateKeyFile = builder.sshPrivateKeyFile; this.remoteHostname = builder.sshRemoteHostname; this.sshPrivateKeyPassphrase = builder.sshPrivateKeyPassphrase; this.sshStrictHostKeyChecking = builder.sshStrictHostKeyChecking; this.sshKnownHostsFile = builder.sshKnownHostsFile; LOGGER.debug("sshUser='{}' sshHostname='{}' sshPrivateKeyFile='{}' remoteHostname'{}" + " sshPrivateKeyPassphrase='{}' sshStrictHostKeyChecking='{}' sshKnownHostsFile='{}'", this.sshUser, this.sshHostname, this.sshPrivateKeyFile, this.remoteHostname, this.sshPrivateKeyPassphrase, this.sshStrictHostKeyChecking, this.sshKnownHostsFile ); } /** * Gets the hash string for the SSH properties provided. * * @param sshUser the username credential for the SSH tunnel. * @param sshHostname the hostname (or IP address) for the SSH tunnel. * @param sshPrivateKeyFile the path to the private key file. * * @return a String value representing the hash of the given properties. */ static String getHashString( final String sshUser, final String sshHostname, final String sshPrivateKeyFile, final String remoteHostname) { final String sshPropertiesString = sshUser + "-" + sshHostname + "-" + sshPrivateKeyFile + remoteHostname; return Hashing.sha256() .hashString(sshPropertiesString, StandardCharsets.UTF_8) .toString(); } /** * Initializes the SSH session and creates a port forwarding tunnel. * * @param connectionProperties the {@link DocumentDbConnectionProperties} connection properties. * @return a {@link Session} session. This session must be closed by calling the * {@link Session#disconnect()} method. * @throws SQLException if unable to create SSH session or create the port forwarding tunnel. */ public static SshPortForwardingSession createSshTunnel( final DocumentDbConnectionProperties connectionProperties) throws SQLException { validateSshPrivateKeyFile(connectionProperties); LOGGER.debug("Internal SSH tunnel starting."); try { final JSch jSch = new JSch(); addIdentity(connectionProperties, jSch); final Session session = createSession(connectionProperties, jSch); connectSession(connectionProperties, jSch, session); final SshPortForwardingSession portForwardingSession = getPortForwardingSession( connectionProperties, session); LOGGER.debug("Internal SSH tunnel started on local port '{}'.", portForwardingSession.getLocalPort()); LOGGER.debug("Internal SSH tunnel started."); return portForwardingSession; } catch (Exception e) { throw logException(e); } } private static SshPortForwardingSession getPortForwardingSession( final DocumentDbConnectionProperties connectionProperties, final Session session) throws JSchException { final Pair clusterHostAndPort = getHostAndPort( connectionProperties.getHostname(), DEFAULT_DOCUMENTDB_PORT); final int localPort = session.setPortForwardingL( LOCALHOST, 0, clusterHostAndPort.getLeft(), clusterHostAndPort.getRight()); return new SshPortForwardingSession(session, localPort); } private static Pair getHostAndPort( final String hostname, final int defaultPort) { final String clusterHost; final int clusterPort; final int portSeparatorIndex = hostname.indexOf(':'); if (portSeparatorIndex >= 0) { clusterHost = hostname.substring(0, portSeparatorIndex); clusterPort = Integer.parseInt( hostname.substring(portSeparatorIndex + 1)); } else { clusterHost = hostname; clusterPort = defaultPort; } return new ImmutablePair<>(clusterHost, clusterPort); } private static void connectSession( final DocumentDbConnectionProperties connectionProperties, final JSch jSch, final Session session) throws SQLException { setSecurityConfig(connectionProperties, jSch, session); try { session.connect(); } catch (JSchException e) { throw logException(e); } } private static void addIdentity( final DocumentDbConnectionProperties connectionProperties, final JSch jSch) throws JSchException { final String privateKeyFileName = getPath(connectionProperties.getSshPrivateKeyFile(), DocumentDbConnectionProperties.getDocumentDbSearchPaths()).toString(); LOGGER.debug("SSH private key file resolved to '{}'.", privateKeyFileName); // If passPhrase protected, will need to provide this, too. final String passPhrase = !isNullOrWhitespace(connectionProperties.getSshPrivateKeyPassphrase()) ? connectionProperties.getSshPrivateKeyPassphrase() : null; jSch.addIdentity(privateKeyFileName, passPhrase); } private static Session createSession( final DocumentDbConnectionProperties connectionProperties, final JSch jSch) throws SQLException { final String sshUsername = connectionProperties.getSshUser(); final Pair sshHostAndPort = getHostAndPort( connectionProperties.getSshHostname(), DEFAULT_SSH_PORT); setKnownHostsFile(connectionProperties, jSch); try { return jSch.getSession(sshUsername, sshHostAndPort.getLeft(), sshHostAndPort.getRight()); } catch (JSchException e) { throw logException(e); } } private static void setSecurityConfig( final DocumentDbConnectionProperties connectionProperties, final JSch jSch, final Session session) { if (!connectionProperties.getSshStrictHostKeyChecking()) { session.setConfig(STRICT_HOST_KEY_CHECKING, NO); return; } setHostKeyType(connectionProperties, jSch, session); } private static void setHostKeyType( final DocumentDbConnectionProperties connectionProperties, final JSch jSch, final Session session) { final HostKeyRepository keyRepository = jSch.getHostKeyRepository(); final HostKey[] hostKeys = keyRepository.getHostKey(); final Pair sshHostAndPort = getHostAndPort( connectionProperties.getSshHostname(), DEFAULT_SSH_PORT); final HostKey hostKey = Arrays.stream(hostKeys) .filter(hk -> hk.getHost().equals(sshHostAndPort.getLeft())) .findFirst().orElse(null); // This will ensure a match between how the host key was hashed in the known_hosts file. final String hostKeyType = (hostKey != null) ? hostKey.getType() : null; // Append the hash algorithm if (hostKeyType != null) { session.setConfig(SERVER_HOST_KEY, session.getConfig(SERVER_HOST_KEY) + "," + hostKeyType); } // The default behaviour of `ssh-keygen` is to hash known hosts keys session.setConfig(HASH_KNOWN_HOSTS, YES); } private static void setKnownHostsFile( final DocumentDbConnectionProperties connectionProperties, final JSch jSch) throws SQLException { if (!connectionProperties.getSshStrictHostKeyChecking()) { return; } final String knownHostsFilename; knownHostsFilename = getSshKnownHostsFilename(connectionProperties); try { jSch.setKnownHosts(knownHostsFilename); } catch (JSchException e) { throw logException(e); } } private static SQLException logException(final T e) { LOGGER.error(e.getMessage(), e); if (e instanceof SQLException) { return (SQLException) e; } return new SQLException(e.getMessage(), e); } /** * Gets the SSH tunnel service listening port. A value of zero indicates that the * SSH tunnel service is not running. * * @return A port number that the SSH tunnel service is listening on. */ public int getServiceListeningPort() { return session != null ? session.getLocalPort() : 0; } @Override public void close() { synchronized (mutex) { if (session != null) { LOGGER.debug("Internal SSH Tunnel is stopping."); session.getSession().disconnect(); session = null; LOGGER.debug("Internal SSH Tunnel is stopped."); } } } /** * Adds a client to the reference count for this server. If this is the first client, the server * ensures that an SSH Tunnel service is started. * * @throws SQLException When an error occurs trying to start the SSH Tunnel service. */ public void addClient() throws SQLException { // Needs to be synchronized in a single process synchronized (mutex) { cancelScheduledFutureClose(); clientCount.incrementAndGet(); if (session != null && session.getLocalPort() != 0) { return; } validateLocalSshFilesExists(); session = createSshTunnel(getConnectionProperties()); } } /** * Removes a client from the reference count for this server. If the reference count reaches zero, then * the serve attempt to stop the SSH Tunnel service. * * @throws SQLException When an error occur attempting shutdown of the service process. */ public void removeClient() throws SQLException { synchronized (mutex) { // Takes advantage of OR to only decrement if greater than zero. if (clientCount.get() <= 0 || clientCount.decrementAndGet() > 0) { return; } closeSession(); } } /** * Closes the SSH tunnel session. If a close delay is given, delay the * close until that time has passed. * * @throws SQLException In the case the task is interrupted. */ private void closeSession() throws SQLException { cancelScheduledFutureClose(); // Delay the close, if indicated. final long delayMS = getCloseDelayMS(); if (delayMS <= 0) { close(); } else { LOGGER.debug("Close timer is being scheduled."); scheduledFuture = scheduler.schedule(getCloseTimerTask(), delayMS, TimeUnit.MILLISECONDS); } } /** * Gets the {@link Runnable} task to close the SSH tunnel session. * * @return the task to close the SSH tunnel session. */ private Runnable getCloseTimerTask() { return () -> { try { close(); } catch (Exception e) { // Ignore exception on close. LOGGER.warn(e.getMessage(), e); } }; } /** * Cancels the scheduled future to close the SSH tunnel session in the case a new client gets added before * the close occurs. * * @throws SQLException If interrupted during sleep. */ private void cancelScheduledFutureClose() throws SQLException { synchronized (mutex) { if (scheduledFuture != null) { LOGGER.debug("Close timer is being cancelled."); while (!scheduledFuture.isDone()) { scheduledFuture.cancel(false); try { TimeUnit.MILLISECONDS.sleep(10); } catch (InterruptedException e) { throw new SQLException(e.getMessage(), e); } } } scheduledFuture = null; } } @VisibleForTesting long getCloseDelayMS() { return closeDelayMS; } @VisibleForTesting void setCloseDelayMS(final long closeDelayMS) { this.closeDelayMS = closeDelayMS > 0 ? closeDelayMS : 0; } /** * Gets the number of clients using the server. * * @return The number of clients using the server. */ @VisibleForTesting long getClientCount() { synchronized (mutex) { return clientCount.get(); } } /** * Checks the state of the SSH tunnel service. * * @return Returns true if the SSH tunnel service is running. */ public boolean isAlive() { return session != null; } /** * Factory method for the {@link DocumentDbSshTunnelServerBuilder} class. * * @param user the SSH tunnel username. * @param hostname the SSH tunnel hostname. * @param privateKeyFile the SSH tunnel private key file path. * @param remoteHostname the hostname of the remote server. * * @return a new {@link DocumentDbSshTunnelServerBuilder} instance. */ public static DocumentDbSshTunnelServerBuilder builder( final String user, final String hostname, final String privateKeyFile, final String remoteHostname) { return new DocumentDbSshTunnelServerBuilder(user, hostname, privateKeyFile, remoteHostname); } /** * The {@link DocumentDbSshTunnelServer} builder class. * A call to the {@link #build()} method returns the single instance with * the matching SSH tunnel properties. */ public static class DocumentDbSshTunnelServerBuilder { private final String sshUser; private final String sshHostname; private final String sshPrivateKeyFile; private final String sshRemoteHostname; private String sshPrivateKeyPassphrase = null; private boolean sshStrictHostKeyChecking = true; private String sshKnownHostsFile = null; private static final ConcurrentMap SSH_TUNNEL_MAP = new ConcurrentHashMap<>(); /** * A builder class for the DocumentDbSshTunnelServer. * * @param sshUser the SSH tunnel username. * @param sshHostname the SSH tunnel hostname. * @param sshPrivateKeyFile the SSH tunnel private key file path. * @param sshRemoteHostname the hostname of the remote server. */ DocumentDbSshTunnelServerBuilder( final String sshUser, final String sshHostname, final String sshPrivateKeyFile, final String sshRemoteHostname) { this.sshUser = sshUser; this.sshHostname = sshHostname; this.sshPrivateKeyFile = sshPrivateKeyFile; this.sshRemoteHostname = sshRemoteHostname; } /** * Sets the private key passphrase. * * @param sshPrivateKeyPassphrase the private key passphrase. * @return the current instance of the builder. */ public DocumentDbSshTunnelServerBuilder sshPrivateKeyPassphrase(final String sshPrivateKeyPassphrase) { this.sshPrivateKeyPassphrase = sshPrivateKeyPassphrase; return this; } /** * Sets the strict host key checking option. * * @param sshStrictHostKeyChecking indicator of whether to set the strict host key checking option. * @return the current instance of the builder. */ public DocumentDbSshTunnelServerBuilder sshStrictHostKeyChecking(final boolean sshStrictHostKeyChecking) { this.sshStrictHostKeyChecking = sshStrictHostKeyChecking; return this; } /** * Sets the known hosts file property. * * @param sshKnownHostsFile the file path to the known hosts file. * * @return the current instance of the builder. */ public DocumentDbSshTunnelServerBuilder sshKnownHostsFile(final String sshKnownHostsFile) { this.sshKnownHostsFile = sshKnownHostsFile; return this; } /** * Builds a DocumentDbSshTunnelServer from the given properties. * * @return a new instance of DocumentDbSshTunnelServer. */ public DocumentDbSshTunnelServer build() { final String hashString = getHashString( this.sshUser, this.sshHostname, this.sshPrivateKeyFile, this.sshRemoteHostname ); // Returns single instance of server for the hashed properties. return SSH_TUNNEL_MAP.computeIfAbsent( hashString, key -> new DocumentDbSshTunnelServer(this) ); } } @NonNull private DocumentDbConnectionProperties getConnectionProperties() { final DocumentDbConnectionProperties connectionProperties = new DocumentDbConnectionProperties(); connectionProperties.setHostname(remoteHostname); connectionProperties.setSshUser(sshUser); connectionProperties.setSshHostname(sshHostname); connectionProperties.setSshPrivateKeyFile(sshPrivateKeyFile); connectionProperties.setSshStrictHostKeyChecking(String.valueOf(sshStrictHostKeyChecking)); if (sshPrivateKeyPassphrase != null) { connectionProperties.setSshPrivateKeyPassphrase(sshPrivateKeyPassphrase); } if (sshKnownHostsFile != null) { connectionProperties.setSshKnownHostsFile(sshKnownHostsFile); } return connectionProperties; } private void validateLocalSshFilesExists() throws SQLException { final DocumentDbConnectionProperties connectionProperties = getConnectionProperties(); validateSshPrivateKeyFile(connectionProperties); getSshKnownHostsFilename(connectionProperties); } static void validateSshPrivateKeyFile(final DocumentDbConnectionProperties connectionProperties) throws SQLException { if (!connectionProperties.isSshPrivateKeyFileExists()) { throw SqlError.createSQLException( LOGGER, SqlState.CONNECTION_EXCEPTION, SqlError.SSH_PRIVATE_KEY_FILE_NOT_FOUND, connectionProperties.getSshPrivateKeyFile()); } } static String getSshKnownHostsFilename(final DocumentDbConnectionProperties connectionProperties) throws SQLException { final String knowHostsFilename; if (!isNullOrWhitespace(connectionProperties.getSshKnownHostsFile())) { final Path knownHostsPath = getPath(connectionProperties.getSshKnownHostsFile()); validateSshKnownHostsFile(connectionProperties, knownHostsPath); knowHostsFilename = knownHostsPath.toString(); } else { knowHostsFilename = getPath(SSH_KNOWN_HOSTS_FILE).toString(); } return knowHostsFilename; } private static void validateSshKnownHostsFile( final DocumentDbConnectionProperties connectionProperties, final Path knownHostsPath) throws SQLException { if (!Files.exists(knownHostsPath)) { throw SqlError.createSQLException( LOGGER, SqlState.INVALID_PARAMETER_VALUE, SqlError.KNOWN_HOSTS_FILE_NOT_FOUND, connectionProperties.getSshKnownHostsFile()); } } /** * Container for the SSH port forwarding tunnel session. */ @Getter @AllArgsConstructor static class SshPortForwardingSession { /** * Gets the SSH session. */ private final Session session; /** * Gets the local port for the port forwarding tunnel. */ private final int localPort; } }