package com.amazon.redshift.plugin; import com.amazonaws.ClientConfiguration; import com.amazonaws.SdkClientException; import com.amazonaws.auth.AWSCredentials; import com.amazonaws.auth.AWSCredentialsProvider; import com.amazonaws.auth.AWSStaticCredentialsProvider; import com.amazonaws.auth.AnonymousAWSCredentials; import com.amazonaws.auth.BasicSessionCredentials; import com.amazonaws.services.securitytoken.AWSSecurityTokenService; import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder; import com.amazonaws.services.securitytoken.model.AssumeRoleWithSAMLRequest; import com.amazonaws.services.securitytoken.model.AssumeRoleWithSAMLResult; import com.amazonaws.services.securitytoken.model.Credentials; import com.amazonaws.util.StringUtils; import com.amazon.redshift.CredentialsHolder; import com.amazon.redshift.CredentialsHolder.IamMetadata; import com.amazon.redshift.IPlugin; import com.amazon.redshift.RedshiftProperty; import com.amazon.redshift.core.IamHelper; import com.amazon.redshift.httpclient.log.IamCustomLogFactory; import com.amazon.redshift.logger.LogLevel; import com.amazon.redshift.logger.RedshiftLogger; import com.amazon.redshift.plugin.utils.RequestUtils; import java.io.ByteArrayInputStream; import java.io.IOException; import java.net.MalformedURLException; import java.net.URI; import java.net.URL; import java.util.ArrayList; import java.util.Collections; import java.util.Date; import java.util.Enumeration; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; import javax.xml.parsers.DocumentBuilder; import javax.xml.parsers.DocumentBuilderFactory; import javax.xml.parsers.ParserConfigurationException; import javax.xml.xpath.XPath; import javax.xml.xpath.XPathConstants; import javax.xml.xpath.XPathExpressionException; import javax.xml.xpath.XPathFactory; import org.apache.commons.codec.binary.Base64; import org.apache.commons.logging.LogFactory; import org.w3c.dom.Document; import org.w3c.dom.Node; import org.w3c.dom.NodeList; import org.xml.sax.SAXException; public abstract class SamlCredentialsProvider extends IdpCredentialsProvider implements IPlugin { protected static final String KEY_IDP_HOST = "idp_host"; private static final String KEY_IDP_PORT = "idp_port"; private static final String KEY_DURATION = "duration"; private static final String KEY_PREFERRED_ROLE = "preferred_role"; protected String m_userName; protected String m_password; protected String m_idpHost; protected int m_idpPort = 443; protected int m_duration; protected String m_preferredRole; protected String m_dbUser; protected String m_dbGroups; protected String m_dbGroupsFilter; protected Boolean m_forceLowercase; protected Boolean m_autoCreate; protected String m_stsEndpoint; protected String m_region; protected Boolean m_disableCache = false; protected Boolean m_groupFederation = false; private static Map m_cache = new HashMap(); private CredentialsHolder m_lastRefreshCredentials; // Used when cache is disable. /** * The custom log factory class. */ private static final Class CUSTOM_LOG_FACTORY_CLASS = IamCustomLogFactory.class; /** * Log properties file name. */ private static final String LOG_PROPERTIES_FILE_NAME = "log-factory.properties"; /** * Log properties file path. */ private static final String LOG_PROPERTIES_FILE_PATH = "META-INF/services/org.apache.commons.logging.LogFactory"; /** * A custom context class loader which allows us to control which LogFactory is loaded. * Our CUSTOM_LOG_FACTORY_CLASS will divert any wire logging to NoOpLogger to suppress wire * messages being logged. */ private static final ClassLoader CONTEXT_CLASS_LOADER = new ClassLoader( SamlCredentialsProvider.class.getClassLoader()) { @Override public Class loadClass(String name) throws ClassNotFoundException { Class clazz = getParent().loadClass(name); if (org.apache.commons.logging.LogFactory.class.isAssignableFrom(clazz)) { return CUSTOM_LOG_FACTORY_CLASS; } return clazz; } @Override public Enumeration getResources(String name) throws IOException { if (LogFactory.FACTORY_PROPERTIES.equals(name)) { // make sure not load any other commons-logging.properties files return Collections.enumeration(Collections.emptyList()); } return super.getResources(name); } @Override public URL getResource(String name) { if (LOG_PROPERTIES_FILE_PATH.equals(name)) { return SamlCredentialsProvider.class.getResource(LOG_PROPERTIES_FILE_NAME); } return super.getResource(name); } }; protected abstract String getSamlAssertion() throws IOException; @Override public void addParameter(String key, String value) { if (RedshiftLogger.isEnable()) m_log.logDebug("key: {0}", key); if (RedshiftProperty.UID.getName().equalsIgnoreCase(key) || RedshiftProperty.USER.getName().equalsIgnoreCase(key)) { m_userName = value; } else if (RedshiftProperty.PWD.getName().equalsIgnoreCase(key) || RedshiftProperty.PASSWORD.getName().equalsIgnoreCase(key)) { m_password = value; } else if (KEY_IDP_HOST.equalsIgnoreCase(key)) { m_idpHost = value; } else if (KEY_IDP_PORT.equalsIgnoreCase(key)) { m_idpPort = Integer.parseInt(value); } else if (KEY_DURATION.equalsIgnoreCase(key)) { m_duration = Integer.parseInt(value); } else if (KEY_PREFERRED_ROLE.equalsIgnoreCase(key)) { m_preferredRole = value; } else if (KEY_SSL_INSECURE.equalsIgnoreCase(key)) { m_sslInsecure = Boolean.parseBoolean(value); } else if (RedshiftProperty.DB_USER.getName().equalsIgnoreCase(key)) { m_dbUser = value; } else if (RedshiftProperty.DB_GROUPS.getName().equalsIgnoreCase(key)) { m_dbGroups = value; } else if (RedshiftProperty.DB_GROUPS_FILTER.getName().equalsIgnoreCase(key)) { m_dbGroupsFilter = value; } else if (RedshiftProperty.FORCE_LOWERCASE.getName().equalsIgnoreCase(key)) { m_forceLowercase = Boolean.valueOf(value); } else if (RedshiftProperty.USER_AUTOCREATE.getName().equalsIgnoreCase(key)) { m_autoCreate = Boolean.valueOf(value); } else if (RedshiftProperty.AWS_REGION.getName().equalsIgnoreCase(key)) { m_region = value; } else if (RedshiftProperty.STS_ENDPOINT_URL.getName().equalsIgnoreCase(key)) { m_stsEndpoint = value; } else if (RedshiftProperty.IAM_DISABLE_CACHE.getName().equalsIgnoreCase(key)) { m_disableCache = Boolean.valueOf(value); } } @Override public void setLogger(RedshiftLogger log) { m_log = log; } @Override public int getSubType() { return IamHelper.SAML_PLUGIN; } @Override public CredentialsHolder getCredentials() { CredentialsHolder credentials = null; if(!m_disableCache) { String key = getCacheKey(); credentials = m_cache.get(key); } if (credentials == null || credentials.isExpired()) { if(RedshiftLogger.isEnable()) m_log.logInfo("SAML getCredentials NOT from cache"); synchronized(this) { refresh(); if(m_disableCache) { credentials = m_lastRefreshCredentials; m_lastRefreshCredentials = null; } } } else { credentials.setRefresh(false); if(RedshiftLogger.isEnable()) m_log.logInfo("SAML getCredentials from cache"); } if(!m_disableCache) { // if the SAML response has dbUser argument, it will be picked up at this point. credentials = m_cache.get(getCacheKey()); } // if dbUser argument has been passed in the connection string, add it to metadata. if (!StringUtils.isNullOrEmpty(m_dbUser)) { credentials.getThisMetadata().setDbUser(this.m_dbUser); } if (credentials == null) { throw new SdkClientException("Unable to load AWS credentials from ADFS"); } if(RedshiftLogger.isEnable()) { Date now = new Date(); m_log.logInfo(now + ": Using entry for SamlCredentialsProvider.getCredentials cache with expiration " + credentials.getExpiration()); } return credentials; } @Override public void refresh() { // Get the current thread and set the context loader with our custom load class method. Thread currentThread = Thread.currentThread(); ClassLoader cl = currentThread.getContextClassLoader(); Thread.currentThread().setContextClassLoader(CONTEXT_CLASS_LOADER); try { String samlAssertion = getSamlAssertion(); if (RedshiftLogger.isEnable()) m_log.logDebug( String.format("SAML assertion: %s", samlAssertion)); final Pattern SAML_PROVIDER_PATTERN = Pattern.compile("arn:aws[-a-z]*:iam::\\d*:saml-provider/\\S+"); final Pattern ROLE_PATTERN = Pattern.compile("arn:aws[-a-z]*:iam::\\d*:role/\\S+"); Document doc = parse(Base64.decodeBase64(samlAssertion)); XPath xPath = XPathFactory.newInstance().newXPath(); String expression = "//*[local-name()='Attribute'][@Name='https://aws.amazon.com/SAML/Attributes/Role']/*[local-name()='AttributeValue']/text()"; NodeList nodeList = (NodeList) xPath.compile(expression) .evaluate(doc, XPathConstants.NODESET); Map roles = new HashMap(); if (nodeList != null) { for (int i = 0; i < nodeList.getLength(); ++i) { Node node = nodeList.item(i); String value = node.getNodeValue(); String[] arns = value.split(","); if (arns.length >= 2) { String provider = null; String role = null; for (String arn : arns) { Matcher providerMatcher = SAML_PROVIDER_PATTERN.matcher(arn); if (providerMatcher.find()) { provider = providerMatcher.group(0); continue; } Matcher roleMatcher = ROLE_PATTERN.matcher(arn); if (roleMatcher.find()) { role = roleMatcher.group(0); } } if (!StringUtils.isNullOrEmpty(role) && !StringUtils.isNullOrEmpty(provider)) { roles.put(role, provider); } } } } if (roles.isEmpty()) { throw new SdkClientException("No role found in SamlAssertion: " + samlAssertion); } String roleArn; String principal; if (m_preferredRole != null) { roleArn = m_preferredRole; principal = roles.get(m_preferredRole); if (principal == null) { throw new SdkClientException("Preferred role not found in SamlAssertion: " + samlAssertion); } } else { Map.Entry entry = roles.entrySet().iterator().next(); roleArn = entry.getKey(); principal = entry.getValue(); } AssumeRoleWithSAMLRequest samlRequest = new AssumeRoleWithSAMLRequest(); samlRequest.setSAMLAssertion(samlAssertion); samlRequest.setRoleArn(roleArn); samlRequest.setPrincipalArn(principal); if (m_duration > 0) { samlRequest.setDurationSeconds(m_duration); } AWSCredentialsProvider p = new AWSStaticCredentialsProvider(new AnonymousAWSCredentials()); AWSSecurityTokenServiceClientBuilder builder = AWSSecurityTokenServiceClientBuilder.standard(); ClientConfiguration config = null; builder.withClientConfiguration(config); AWSSecurityTokenService stsSvc = RequestUtils.buildSts(m_stsEndpoint, m_region, builder, p, m_log); AssumeRoleWithSAMLResult result = stsSvc.assumeRoleWithSAML(samlRequest); Credentials cred = result.getCredentials(); Date expiration = cred.getExpiration(); AWSCredentials c = new BasicSessionCredentials(cred.getAccessKeyId(), cred.getSecretAccessKey(), cred.getSessionToken()); CredentialsHolder credentials = CredentialsHolder.newInstance(c, expiration); credentials.setMetadata(readMetadata(doc)); credentials.setRefresh(true); if(!m_disableCache) m_cache.put(getCacheKey(), credentials); else m_lastRefreshCredentials = credentials; } catch (IOException e) { if (RedshiftLogger.isEnable()) m_log.logError(e); throw new SdkClientException("SAML error: " + e.getMessage(), e); } catch (SAXException e) { if (RedshiftLogger.isEnable()) m_log.logError(e); throw new SdkClientException("SAML error: " + e.getMessage(), e); } catch (ParserConfigurationException e) { if (RedshiftLogger.isEnable()) m_log.logError(e); throw new SdkClientException("SAML error: " + e.getMessage(), e); } catch (XPathExpressionException e) { if (RedshiftLogger.isEnable()) m_log.logError(e); throw new SdkClientException("SAML error: " + e.getMessage(), e); } catch (Exception e) { if (RedshiftLogger.isEnable()) m_log.logError(e); throw new SdkClientException("SAML error: " + e.getMessage(), e); } finally { currentThread.setContextClassLoader(cl); } } @Override public String getPluginSpecificCacheKey() { // Override this in each derived plugin such as Azure, Browser, Okta, Ping etc. return ""; } @Override public String getIdpToken() { String samlAssertion = null; // Get the current thread and set the context loader with our custom load class method. Thread currentThread = Thread.currentThread(); ClassLoader cl = currentThread.getContextClassLoader(); Thread.currentThread().setContextClassLoader(CONTEXT_CLASS_LOADER); try { samlAssertion = getSamlAssertion(); if (RedshiftLogger.isEnable()) m_log.logDebug( String.format("SAML assertion: %s", samlAssertion)); } catch (IOException e) { if (RedshiftLogger.isEnable()) m_log.logError(e); throw new SdkClientException("SAML error: " + e.getMessage(), e); } catch (Exception e) { if (RedshiftLogger.isEnable()) m_log.logError(e); throw new SdkClientException("SAML error: " + e.getMessage(), e); } finally { currentThread.setContextClassLoader(cl); } return samlAssertion; } @Override public void setGroupFederation(boolean groupFederation) { m_groupFederation = groupFederation; } @Override public String getCacheKey() { String pluginSpecificKey = getPluginSpecificCacheKey(); return m_userName + m_password + m_idpHost + m_idpPort + m_duration + m_preferredRole + pluginSpecificKey; } private IamMetadata readMetadata(Document doc) throws XPathExpressionException { IamMetadata metadata = new IamMetadata(); XPath xPath = XPathFactory.newInstance().newXPath(); List attributeValues = GetSAMLAttributeValues(xPath, doc, "https://redshift.amazon.com/SAML/Attributes/AllowDbUserOverride"); if (!attributeValues.isEmpty()) { metadata.setAllowDbUserOverride(Boolean.valueOf(attributeValues.get(0))); } attributeValues = GetSAMLAttributeValues(xPath, doc, "https://redshift.amazon.com/SAML/Attributes/DbUser"); if (!attributeValues.isEmpty()) { metadata.setSamlDbUser(attributeValues.get(0)); } else { attributeValues = GetSAMLAttributeValues(xPath, doc, "https://aws.amazon.com/SAML/Attributes/RoleSessionName"); if (!attributeValues.isEmpty()) { metadata.setSamlDbUser(attributeValues.get(0)); } } attributeValues = GetSAMLAttributeValues(xPath, doc, "https://redshift.amazon.com/SAML/Attributes/AutoCreate"); if (!attributeValues.isEmpty()) { metadata.setAutoCreate(Boolean.valueOf(attributeValues.get(0))); } attributeValues = GetSAMLAttributeValues(xPath, doc, "https://redshift.amazon.com/SAML/Attributes/DbGroups"); if (!attributeValues.isEmpty()) { attributeValues = filterOutGroups(attributeValues); if (!attributeValues.isEmpty()) { StringBuilder sb = new StringBuilder(); for (String value : attributeValues) { if (sb.length() > 0) { sb.append(','); } sb.append(value); } metadata.setDbGroups(sb.toString()); } } attributeValues = GetSAMLAttributeValues(xPath, doc, "https://redshift.amazon.com/SAML/Attributes/ForceLowercase"); if (!attributeValues.isEmpty()) { metadata.setForceLowercase(Boolean.valueOf(attributeValues.get(0))); } return metadata; } /** * Method removes all groups from given lists matching {@link m_dbGroupsFilter} * regex. * @param attributeValues in * @return attributeValues filtered */ private List filterOutGroups(List attributeValues) { if ( m_dbGroupsFilter != null ) { final Pattern groupsFilter = Pattern.compile(m_dbGroupsFilter); List ret = new ArrayList<>(); for (String attributeValue : attributeValues) { m_log.logDebug("Check group {0} with regexp {1}", attributeValue, m_dbGroupsFilter); if (!groupsFilter.matcher(attributeValue).matches()) { m_log.logDebug("Add {0} to dbgroups", attributeValue); ret.add(attributeValue); } } return ret; } else { return attributeValues; } } private static Document parse(byte[] samlAssertion) throws IOException, SAXException, ParserConfigurationException { DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance(); factory.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true); factory.setXIncludeAware(false); factory.setExpandEntityReferences(false); factory.setFeature("http://xml.org/sax/features/external-parameter-entities", false); factory.setFeature("http://xml.org/sax/features/external-general-entities", false); DocumentBuilder db = factory.newDocumentBuilder(); return db.parse(new ByteArrayInputStream(samlAssertion)); } private static List GetSAMLAttributeValues(XPath xPath, Document doc, String attributeName) throws XPathExpressionException { String expression = String.format("//Attribute[@Name='%s']/AttributeValue/text()", attributeName); NodeList nodeList = (NodeList) xPath.compile(expression).evaluate(doc, XPathConstants.NODESET); if (null == nodeList || nodeList.getLength() == 0) { return Collections.emptyList(); } List attributeValues = new ArrayList(nodeList.getLength()); for (int i = 0; i < nodeList.getLength(); ++i) { Node node = nodeList.item(i); attributeValues.add(node.getNodeValue()); } return attributeValues; } protected List getInputTagsfromHTML(String body) { Set distinctInputTags = new HashSet<>(); List inputTags = new ArrayList(); Pattern inputTagPattern = Pattern.compile("", Pattern.DOTALL); Matcher inputTagMatcher = inputTagPattern.matcher(body); while (inputTagMatcher.find()) { String tag = inputTagMatcher.group(0); String tagNameLower = getValueByKey(tag, "name").toLowerCase(); if (!tagNameLower.isEmpty() && distinctInputTags.add(tagNameLower)) { inputTags.add(tag); } } return inputTags; } protected String getFormAction(String body) { Pattern pattern = Pattern.compile("'); i += 4; } else { sb.append(c); ++i; } } return sb.toString(); } protected void checkRequiredParameters() throws IOException { if (StringUtils.isNullOrEmpty(m_userName)) { throw new IOException("Missing required property: " + RedshiftProperty.USER.getName()); } if (StringUtils.isNullOrEmpty(m_password)) { throw new IOException("Missing required property: " + RedshiftProperty.PASSWORD.getName()); } if (StringUtils.isNullOrEmpty(m_idpHost)) { throw new IOException("Missing required property: " + KEY_IDP_HOST); } } protected boolean isText(String inputTag) { String typeVal = getValueByKey(inputTag, "type"); if(typeVal == null || typeVal.length() == 0) { typeVal = getValueByKeyWithoutQuotesAndValueInSingleQuote(inputTag, "type"); } return "text".equals(typeVal); } protected boolean isPassword(String inputTag) { String typeVal = getValueByKey(inputTag, "type"); if(typeVal == null || typeVal.length() == 0) { typeVal = getValueByKeyWithoutQuotesAndValueInSingleQuote(inputTag, "type"); } return "password".equals(typeVal); } }