#!/bin/bash
# Helper script to connect to a remote managed instance with SSM and start SSH port forwarding.
# Every time it generates a new SSH key at ~/.ssh/sagemaker-ssh-gw and transfers the public part
# to the instance via S3 by executing a remote SSM command.
#
# Syntax:
#   sm-connect-ssh-proxy <instance_id> <s3_authorized_keys_path> [ssh_args...]
#
#   instance_id              SSM managed instance ID to connect to.
#   s3_authorized_keys_path  S3 URI the local public key is uploaded to; on the instance
#                            it is copied back recursively into /etc/ssh/authorized_keys.d/.
#   ssh_args                 Extra arguments passed as-is to ssh (e.g. -L / -R forwards).
#
# Expects AWS CLI v2 (uses --no-cli-pager).

set -eo pipefail

# Print an error message to stderr and exit with the given code (default 1).
die() {
  printf 'Error: %s\n' "$1" >&2
  exit "${2:-1}"
}

if [[ $# -lt 2 ]]; then
  echo "Usage: ${0##*/} <instance_id> <s3_authorized_keys_path> [ssh_args...]" >&2
  exit 1
fi

INSTANCE_ID="$1"
SSH_AUTHORIZED_KEYS="$2"
shift 2
# Keep the remaining arguments as an array so each one reaches ssh as a
# separate word, without re-splitting or globbing.
PORT_FWD_ARGS=("$@")

instance_status=$(aws ssm describe-instance-information \
  --filters Key=InstanceIds,Values="$INSTANCE_ID" \
  --query 'InstanceInformationList[0].PingStatus' --output text)
echo "Instance status: $instance_status"
if [[ "$instance_status" != "Online" ]]; then
  die "Instance is offline."
fi

# TODO: make it possible to override the default (also helps avoid race conditions)
SSH_KEY=~/.ssh/sagemaker-ssh-gw

echo "Generating $SSH_KEY keypair with ECDSA and uploading public key to $SSH_AUTHORIZED_KEYS"
mkdir -p -m 700 "$(dirname "$SSH_KEY")"
# Drop any previous keypair so ssh-keygen never stops at its interactive overwrite prompt.
rm -f -- "$SSH_KEY" "${SSH_KEY}.pub"
ssh-keygen -t ecdsa -q -f "$SSH_KEY" -N '' >/dev/null
aws s3 cp "${SSH_KEY}.pub" "${SSH_AUTHORIZED_KEYS}"

# Value column of the "region" row of `aws configure list`; a missing value is
# shown there as "<not set>", which awk splits into "<not".
CURRENT_REGION=$(aws configure list | awk '$1 == "region" {print $2}')
if [[ -z "$CURRENT_REGION" || "$CURRENT_REGION" == "<not" ]]; then
  die "AWS region is not configured (see 'aws configure list')"
fi
echo "Will use AWS Region: $CURRENT_REGION"

AWS_CLI_VERSION=$(aws --version)
echo "AWS CLI version (should be v2): $AWS_CLI_VERSION"

echo "Running SSM commands at region ${CURRENT_REGION} to copy public key to ${INSTANCE_ID}"
command_id=$(aws ssm send-command \
  --region "${CURRENT_REGION}" \
  --instance-ids "${INSTANCE_ID}" \
  --document-name "AWS-RunShellScript" \
  --comment "Copy public key for SSH helper" \
  --timeout-seconds 30 \
  --parameters "commands=[
    'mkdir -p /etc/ssh/authorized_keys.d/',
    'aws s3 cp --recursive \"${SSH_AUTHORIZED_KEYS}\" /etc/ssh/authorized_keys.d/',
    'find /etc/ssh/authorized_keys.d/',
    'cat /etc/ssh/authorized_keys.d/* > /etc/ssh/authorized_keys',
    'find /etc/ssh/authorized_keys'
  ]" \
  --no-cli-pager --no-paginate \
  --query 'Command.CommandId' --output text)
if [[ -z "$command_id" || "$command_id" == "None" ]]; then
  die "Failed to get command ID from send-command"
fi
echo "Got command ID: $command_id"

# Print a single field (JMESPath expression $1) of this command's invocation as plain text.
get_invocation_field() {
  aws ssm get-command-invocation \
    --region "${CURRENT_REGION}" \
    --instance-id "${INSTANCE_ID}" \
    --command-id "${command_id}" \
    --no-cli-pager --no-paginate \
    --query "$1" --output text
}

# Switch to unicode for AWS CLI to properly parse output
export LC_CTYPE=en_US.UTF-8

max_attempts=15
command_status=""
for ((attempt = 1; attempt <= max_attempts; attempt++)); do
  # Right after send-command the invocation may not be visible yet and the call
  # fails; retry instead of letting `set -e` abort the whole script.
  if ! command_status=$(get_invocation_field 'Status'); then
    command_status=""
    sleep 1
    continue
  fi
  echo "Command status: $command_status"
  if [[ "$command_status" != "Pending" && "$command_status" != "InProgress" ]]; then
    output_content=$(get_invocation_field 'StandardOutputContent')
    error_content=$(get_invocation_field 'StandardErrorContent')
    echo "Command output: $output_content"
    if [[ -n "$error_content" ]]; then
      echo "Command error: $error_content"
    fi
    break
  fi
  sleep 1
done

if [[ "$command_status" != "Success" ]]; then
  die "Command didn't finish successfully in time (last status: ${command_status:-unknown})" 2
fi

echo "Connecting to $INSTANCE_ID with SSM and starting SSH port forwarding with the args: ${PORT_FWD_ARGS[*]}"

# TODO: remove duplicating message from SSMProxy
# We don't use AWS-StartPortForwardingSession feature of SSM here, because we need port forwarding in both directions
# with -L and -R parameters of SSH. This is useful for forwarding the PyCharm license server, which needs -R option.
# SSM allows only forwarding of ports from the server (equivalent to the -L option).
# Backslash-newline inside double quotes is a line continuation, so this is one command line;
# ssh substitutes %p with the target port.
proxy_command="aws ssm start-session\
  --reason 'Local user started SageMaker SSH Helper'\
  --region '${CURRENT_REGION}'\
  --target '${INSTANCE_ID}'\
  --document-name AWS-StartSSHSession\
  --parameters portNumber=%p"

ssh -4 -o User=root -o IdentityFile="${SSH_KEY}" -o IdentitiesOnly=yes \
  -o ProxyCommand="$proxy_command" \
  -o ServerAliveInterval=15 -o ServerAliveCountMax=3 \
  -o PasswordAuthentication=no \
  -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null \
  "${PORT_FWD_ARGS[@]}" "$INSTANCE_ID"