import ast
import base64
import concurrent.futures
import logging
import random
import re
import secrets
import socket
import typing

from redshift_connector.error import InterfaceError
from redshift_connector.plugin.credential_provider_constants import azure_headers
from redshift_connector.plugin.saml_credentials_provider import SamlCredentialsProvider
from redshift_connector.redshift_property import RedshiftProperty

_logger: logging.Logger = logging.getLogger(__name__)


# Class to get SAML Response from Microsoft Azure using OAuth 2.0 API
class BrowserAzureCredentialsProvider(SamlCredentialsProvider):
    """
    Identity Provider Plugin providing federated access to an Amazon Redshift cluster using Azure browser
    authentication.

    The user signs in through the default web browser; Azure AD redirects back to a short-lived listener on
    127.0.0.1, and the resulting authorization code is exchanged for a SAML token. See the Amazon Redshift
    Management Guide ("Options for providing IAM credentials") for setup instructions.
    """

    def __init__(self: "BrowserAzureCredentialsProvider") -> None:
        super().__init__()
        # Azure AD tenant used in the login.microsoftonline.com URLs.
        self.idp_tenant: typing.Optional[str] = None
        # Application (client) id; also sent as the token "resource".
        self.client_id: typing.Optional[str] = None
        # Seconds to wait for the browser redirect (and, below, for the token request).
        self.idp_response_timeout: int = 120
        # Port of the local redirect listener; assigned by get_listen_socket().
        self.listen_port: int = 0
        # Redirect URI derived from listen_port in get_saml_assertion().
        self.redirectUri: typing.Optional[str] = None

    def get_listen_socket(self: "BrowserAzureCredentialsProvider") -> socket.socket:
        """
        Create the listener that receives the browser redirect.

        The socket is bound to 127.0.0.1 on a port chosen by the OS, and that port is stored in
        ``self.listen_port``.

        :return: a listening TCP socket; the caller owns it and must close it.
        :raises OSError: if the socket cannot be bound or put into listening mode.
        """
        s: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            s.bind(("127.0.0.1", 0))  # port 0: let the OS pick any free port
            s.listen()
            self.listen_port = s.getsockname()[1]
        except OSError:
            # Don't leak the descriptor when setup fails.
            s.close()
            raise
        return s

    def add_parameter(self: "BrowserAzureCredentialsProvider", info: RedshiftProperty) -> None:
        """
        Copy the user-supplied connection properties, including the Azure-specific ones.

        :param info: connection properties; ``idp_tenant``, ``client_id`` and ``idp_response_timeout`` are
            read here in addition to whatever the base class reads.
        """
        super().add_parameter(info)
        self.idp_tenant = info.idp_tenant
        self.client_id = info.client_id
        self.idp_response_timeout = info.idp_response_timeout

        _logger.debug("Idp_tenant=%s", self.idp_tenant)
        _logger.debug("Client_id=%s", self.client_id)
        _logger.debug("Idp_response_timeout=%s", self.idp_response_timeout)
        _logger.debug("Listen_port=%s", self.listen_port)

    def get_saml_assertion(self: "BrowserAzureCredentialsProvider") -> str:
        """
        Run the full browser sign-in flow and return a SAML response for the base class.

        The base class uses this to obtain (and refresh) temporary credentials.

        :return: base64-encoded ``samlp:Response`` document wrapping the Azure SAML assertion.
        :raises InterfaceError: if a required property is missing or invalid, or if any step of the flow fails.
        """
        if not self.idp_tenant:
            raise InterfaceError("Missing required property: idp_tenant")
        if not self.client_id:
            raise InterfaceError("Missing required property: client_id")
        if self.idp_response_timeout < 10:
            raise InterfaceError("idp_response_timeout should be 10 seconds or greater.")

        listen_socket: socket.socket = self.get_listen_socket()
        # NOTE(review): the listener is IPv4-only (127.0.0.1); this assumes "localhost" resolves there
        # in the user's browser.
        self.redirectUri = "http://localhost:{port}/redshift/".format(port=self.listen_port)
        _logger.debug("Listening for connection on port %s", self.listen_port)

        try:
            token: str = self.fetch_authorization_token(listen_socket)
            saml_assertion: str = self.fetch_saml_response(token)
        finally:
            listen_socket.close()

        _logger.debug("Got SAML assertion")
        return self.wrap_and_encode_assertion(saml_assertion)

    def fetch_authorization_token(self: "BrowserAzureCredentialsProvider", listen_socket: socket.socket) -> str:
        """
        First authentication phase: obtain an authorization code via the browser.

        A random ``state`` value is generated so that only a redirect belonging to this sign-in attempt is
        accepted. The local server is started on ``listen_socket`` in a worker thread, and the default
        browser is opened on Azure AD's authorization page.

        :param listen_socket: listening socket from :meth:`get_listen_socket`.
        :return: the authorization code received on the redirect.
        :raises OSError: on socket errors, including a timeout waiting for the browser.
        :raises InterfaceError: if the redirect is invalid (see :meth:`run_server`).
        """
        alphabet: str = "abcdefghijklmnopqrstuvwxyz"
        # `state` is an anti-CSRF nonce, so draw it from a CSPRNG rather than `random`.
        state: str = "".join(secrets.choice(alphabet) for _ in range(10))
        try:
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                # Start listening before the browser is opened so the redirect cannot arrive too early.
                future = executor.submit(self.run_server, listen_socket, self.idp_response_timeout, state)
                self.open_browser(state)
                return_value = future.result()
            return str(return_value)
        except socket.error as e:
            _logger.error("Socket error: %s", e)
            raise
        except Exception as e:
            _logger.error("Other Exception: %s", e)
            raise

    def fetch_saml_response(self: "BrowserAzureCredentialsProvider", token: str) -> str:
        """
        Exchange the authorization code for a SAML 2.0 token at Azure AD's token endpoint.

        :param token: authorization code returned by :meth:`fetch_authorization_token`.
        :return: ``str()`` of the base64url-decoded ``access_token`` bytes, i.e. text of the form ``"b'...'"``.
            :meth:`wrap_and_encode_assertion` expects exactly this form.
        :raises InterfaceError: if the HTTP request fails, the response lacks a usable ``access_token``, or the
            token is not valid base64url.
        """
        import requests

        url: str = "https://login.microsoftonline.com/{tenant}/oauth2/token".format(tenant=self.idp_tenant)
        # headers to pass with POST request
        headers: typing.Dict[str, str] = azure_headers
        self.validate_url(url)

        # required parameters to pass in POST body
        payload: typing.Dict[str, typing.Optional[str]] = {
            "code": token,
            "requested_token_type": "urn:ietf:params:oauth:token-type:saml2",
            "grant_type": "authorization_code",
            "scope": "openid",
            "resource": self.client_id,
            "client_id": self.client_id,
            "redirect_uri": self.redirectUri,
        }

        _logger.debug("Uri: %s", url)

        try:
            response = requests.post(
                url,
                data=payload,
                headers=headers,
                verify=self.do_verify_ssl_cert(),
                # requests never times out unless told to; without this the Timeout branch below is dead.
                timeout=self.idp_response_timeout,
            )
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            # Compare with None: a Response carrying a 4xx/5xx status is itself falsy.
            if e.response is not None:
                _logger.debug("Fetch_saml_response https response: %s", e.response.content)
            else:
                _logger.debug("Fetch_saml_response could not receive https response due to an error")
            _logger.error("Request for authentication from Microsoft was unsuccessful. %s", e)
            raise InterfaceError(e) from e
        except requests.exceptions.Timeout as e:
            _logger.error("A timeout occurred when requesting authentication from Azure")
            raise InterfaceError(e) from e
        except requests.exceptions.TooManyRedirects as e:
            _logger.error("A error occurred when requesting authentication from Azure. Verify RedshiftProperties are correct")
            raise InterfaceError(e) from e
        except requests.exceptions.RequestException as e:
            _logger.error("A unknown error occurred when requesting authentication from Azure")
            raise InterfaceError(e) from e

        _logger.debug("Azure authentication response length: %s", len(response.text))

        try:
            saml_assertion: str = response.json()["access_token"]
        except TypeError as e:
            _logger.error("Failed to decode saml assertion returned from Azure")
            raise InterfaceError(e) from e
        except KeyError as e:
            _logger.error("Azure access_token was not found in saml assertion")
            raise InterfaceError(e) from e
        except Exception as e:
            raise InterfaceError(e) from e

        # Also rejects a JSON null access_token, which previously crashed with TypeError in len().
        if not saml_assertion:
            raise InterfaceError("Azure access_token is empty")

        # base64url tokens are usually sent unpadded. (-len) % 4 is 0 for aligned input, whereas the old
        # 4 - len % 4 appended a spurious "====".
        saml_assertion += "=" * (-len(saml_assertion) % 4)
        try:
            decoded: bytes = base64.urlsafe_b64decode(saml_assertion)
        except ValueError as e:  # binascii.Error is a ValueError subclass
            _logger.error("Azure access_token is not valid base64url")
            raise InterfaceError(e) from e
        # NOTE: str(bytes) is the repr ("b'...'"); wrap_and_encode_assertion undoes it. Kept for
        # backward compatibility with callers of this method.
        return str(decoded)

    @staticmethod
    def _unwrap_bytes_repr(saml_assertion: str) -> str:
        """
        Recover the assertion text from the ``str(bytes)`` form returned by :meth:`fetch_saml_response`.

        Merely slicing off ``b'`` and ``'`` leaves repr escapes in place. Non-ASCII characters become
        ``\\xNN`` sequences, newlines become ``\\n``, and quotes and backslashes gain extra backslashes, which
        corrupts the signed XML. Evaluating the literal reverses those escapes. ``ast.literal_eval`` only
        accepts Python literals, so nothing is executed.

        Input that is not a UTF-8 bytes literal falls back to the previous ``[2:-1]`` slice.
        """
        try:
            value = ast.literal_eval(saml_assertion)
            if isinstance(value, bytes):
                return value.decode("utf-8")
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            pass
        return saml_assertion[2:-1]

    def wrap_and_encode_assertion(self: "BrowserAzureCredentialsProvider", saml_assertion: str) -> str:
        """
        Wrap the assertion in a minimal SAML protocol response and base64-encode it for the base class.

        The base class needs at least:

        * a ``samlp:Response`` element declaring the ``samlp`` protocol namespace;
        * a ``samlp:Status`` element whose ``samlp:StatusCode`` value indicates success;
        * followed by the signed SAML assertion.

        :param saml_assertion: output of :meth:`fetch_saml_response` (``str`` of the assertion bytes).
        :return: base64 (standard alphabet) text of the UTF-8 encoded response document.
        """
        assertion_xml: str = self._unwrap_bytes_repr(saml_assertion)
        saml_response: str = (
            '<samlp:Response xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol">'
            "<samlp:Status>"
            '<samlp:StatusCode Value="urn:oasis:names:tc:SAML:2.0:status:Success"/>'
            "</samlp:Status>"
            "{saml_assertion}"
            "</samlp:Response>"
        ).format(saml_assertion=assertion_xml)
        # b64encode output is pure ASCII, so decoding equals the old str(...)[2:-1] trick.
        return base64.b64encode(saml_response.encode("utf-8")).decode("ascii")

    @staticmethod
    def _get_redirect_param(section: str, name: str) -> typing.Optional[str]:
        """
        Return the raw (not URL-decoded) value of parameter ``name`` in ``section``, or ``None`` if it is absent.

        The name must start the section or follow ``?``, ``&`` or whitespace, so ``session_state=`` is never
        read as ``state=``. The value ends at the next ``&``, at whitespace, or at the end of ``section``.
        """
        match = re.search(r"(?:^|[?&\s]){}=([^&\s]*)".format(re.escape(name)), section)
        return match.group(1) if match is not None else None

    @classmethod
    def _find_redirect_params(
        cls: typing.Type["BrowserAzureCredentialsProvider"], request_text: str
    ) -> typing.Optional[typing.Tuple[str, typing.Optional[str]]]:
        """
        Locate ``state`` and ``code`` in a (possibly still partial) HTTP request.

        Only the request body (used by ``response_mode=form_post``) and the request line (query string) are
        searched. Header values such as ``Referer`` are therefore never taken as redirect parameters.

        :return: ``None`` if no ``state`` has been seen yet; otherwise ``(state, code)``, where ``code`` is
            ``None`` if it is missing from the same section.
        """
        head, _, body = request_text.partition("\r\n\r\n")
        request_line: str = head.split("\r\n", 1)[0]
        for section in (body, request_line):
            received_state = cls._get_redirect_param(section, "state")
            if received_state is not None:
                return received_state, cls._get_redirect_param(section, "code")
        return None

    def run_server(
        self: "BrowserAzureCredentialsProvider", listen_socket: socket.socket, idp_response_timeout: int, state: str
    ) -> str:
        """
        Accept the browser's redirect on the local listener and return the authorization code.

        :param listen_socket: listening socket from :meth:`get_listen_socket`.
        :param idp_response_timeout: seconds to wait for the browser to connect, and for each read.
        :param state: the nonce sent in the authorization request; the redirect must echo it back.
        :return: the raw value of the ``code`` parameter.
        :raises InterfaceError: on a state mismatch, on a missing or empty code, or if the browser closes the
            connection before sending the parameters.
        :raises socket.timeout: if the browser does not connect or send data within the timeout.
        """
        timeout: float = float(idp_response_timeout)
        # Without this, accept() blocks forever when the user never completes sign-in, and
        # fetch_authorization_token() can never return because its executor waits for this thread.
        listen_socket.settimeout(timeout)
        conn, _addr = listen_socket.accept()
        conn.settimeout(timeout)
        size: int = 102400
        received: bytes = b""
        with conn:
            while True:
                part: bytes = conn.recv(size)
                if not part:
                    # recv() returns b"" once the peer has closed; looping again would spin forever.
                    raise InterfaceError("Connection closed before the authorization response was received")
                # Accumulate so that parameters split across several reads are still found.
                received += part
                params = self._find_redirect_params(received.decode("utf-8", errors="replace"))
                if params is None:
                    continue  # headers only so far; the body is still in flight
                received_state, received_code = params
                if received_state != state:
                    raise InterfaceError(
                        "Incoming state {received} does not match the outgoing state {expected}".format(
                            received=received_state, expected=state
                        )
                    )
                if received_code is None:
                    raise InterfaceError("No code found")
                if received_code == "":
                    raise InterfaceError("No valid code found")
                conn.sendall(self.close_window_http_resp())
                return received_code

    def open_browser(self: "BrowserAzureCredentialsProvider", state: str) -> None:
        """
        Open the default browser on Azure AD's authorization endpoint.

        The request asks for an authorization code to be POSTed (``response_mode=form_post``) to
        ``self.redirectUri``, with ``state`` echoed back.

        :param state: nonce that :meth:`run_server` will require on the redirect.
        """
        import webbrowser

        url: str = (
            "https://login.microsoftonline.com/{tenant}"
            "/oauth2/authorize"
            "?scope=openid"
            "&response_type=code"
            "&response_mode=form_post"
            "&client_id={id}"
            "&redirect_uri={uri}"
            "&state={state}".format(tenant=self.idp_tenant, id=self.client_id, uri=self.redirectUri, state=state)
        )
        self.validate_url(url)
        webbrowser.open(url)