import base64
import logging
import typing
from redshift_connector.error import InterfaceError
from redshift_connector.plugin.credential_provider_constants import azure_headers
from redshift_connector.plugin.saml_credentials_provider import SamlCredentialsProvider
from redshift_connector.redshift_property import RedshiftProperty
_logger: logging.Logger = logging.getLogger(__name__)
# Class to get SAML Response from Microsoft Azure using OAuth 2.0 API
class AzureCredentialsProvider(SamlCredentialsProvider):
"""
Identity Provider Plugin providing single sign-on access to an Amazon Redshift cluster using Azure,
See `Amazon Redshift docs `_
for setup instructions.
"""
def __init__(self: "AzureCredentialsProvider") -> None:
super().__init__()
self.idp_tenant: typing.Optional[str] = None
self.client_secret: typing.Optional[str] = None
self.client_id: typing.Optional[str] = None
# method to grab the field parameters specified by end user.
# This method adds to it Azure specific parameters.
def add_parameter(self: "AzureCredentialsProvider", info: RedshiftProperty) -> None:
super().add_parameter(info)
# The value of parameter idp_tenant.
self.idp_tenant = info.idp_tenant
# The value of parameter client_secret.
self.client_secret = info.client_secret
# The value of parameter client_id.
self.client_id = info.client_id
# Required method to grab the SAML Response. Used in base class to refresh temporary credentials.
def get_saml_assertion(self: "AzureCredentialsProvider") -> str:
# idp_tenant, client_secret, and client_id are
# all required parameters to be able to authenticate with Microsoft Azure.
# user and password are also required and need to be set to the username and password of the
# Microsoft Azure account that is logging in.
if self.user_name == "" or self.user_name is None:
raise InterfaceError("Missing required property: user_name")
if self.password == "" or self.password is None:
raise InterfaceError("Missing required property: password")
if self.idp_tenant == "" or self.idp_tenant is None:
raise InterfaceError("Missing required property: idp_tenant")
if self.client_secret == "" or self.client_secret is None:
raise InterfaceError("Missing required property: client_secret")
if self.client_id == "" or self.client_id is None:
raise InterfaceError("Missing required property: client_id")
return self.azure_oauth_based_authentication()
# Method to initiate a POST request to grab the SAML Assertion from Microsoft Azure
# and convert it to a SAML Response.
def azure_oauth_based_authentication(self: "AzureCredentialsProvider") -> str:
import requests
# endpoint to connect with Microsoft Azure to get SAML Assertion token
url: str = "https://login.microsoftonline.com/{tenant}/oauth2/token".format(tenant=self.idp_tenant)
_logger.debug("Uri: {}".format(url))
self.validate_url(url)
# headers to pass with POST request
headers: typing.Dict[str, str] = azure_headers
# required parameters to pass in POST body
payload: typing.Dict[str, typing.Optional[str]] = {
"grant_type": "password",
"requested_token_type": "urn:ietf:params:oauth:token-type:saml2",
"username": self.user_name,
"password": self.password,
"client_secret": self.client_secret,
"client_id": self.client_id,
"resource": self.client_id,
}
try:
response: "requests.Response" = requests.post(
url, data=payload, headers=headers, verify=self.do_verify_ssl_cert()
)
response.raise_for_status()
except requests.exceptions.HTTPError as e:
if "response" in vars():
_logger.debug(
"azure_oauth_based_authentication https response: {}".format(response.content) # type: ignore
)
else:
_logger.debug("Azure_oauth_based_authentication could not receive https response due to an error")
_logger.error("Request for authentication from Azure was unsuccessful. {}".format(str(e)))
raise InterfaceError(e)
except requests.exceptions.Timeout as e:
_logger.error("A timeout occurred when requesting authentication from Azure")
raise InterfaceError(e)
except requests.exceptions.TooManyRedirects as e:
_logger.error(
"A error occurred when requesting authentication from Azure. Verify RedshiftProperties are correct"
)
raise InterfaceError(e)
except requests.exceptions.RequestException as e:
_logger.error("A unknown error occurred when requesting authentication from Azure.")
raise InterfaceError(e)
_logger.debug("Azure Oauth authentication response length: {}".format(len(response.text)))
# parse the JSON response to grab access_token field which contains Base64 encoded SAML
# Assertion and decode it
saml_assertion: str = ""
try:
saml_assertion = response.json()["access_token"]
except Exception as e:
_logger.error("Failed to authenticate with Azure. Response from Azure did not include access_token.")
raise InterfaceError(e)
if saml_assertion == "":
raise InterfaceError("Azure access_token is empty")
missing_padding: int = 4 - len(saml_assertion) % 4
if missing_padding:
saml_assertion += "=" * missing_padding
# decode the SAML Assertion to a String to add XML tags to form a SAML Response
decoded_saml_assertion: str = ""
try:
decoded_saml_assertion = str(base64.urlsafe_b64decode(saml_assertion))
except TypeError as e:
_logger.error("Failed to decode saml assertion returned from Azure")
raise InterfaceError(e)
# SAML Response is required to be sent to base class. We need to provide a minimum of:
# 1) samlp:Response XML tag with xmlns:samlp protocol value
# 2) samlp:Status XML tag and samlpStatusCode XML tag with Value indicating Success
# 3) followed by Signed SAML Assertion
saml_response: str = (
''
""
''
""
"{decoded_saml_assertion}"
"".format(decoded_saml_assertion=decoded_saml_assertion[2:-1])
)
# re-encode the SAML Response in Base64 and return this to the base class
saml_response = str(base64.b64encode(saml_response.encode("utf-8")))[2:-1]
return saml_response