From c8da5bf5323d3a49802fd8b3265220f71ad9b82f Mon Sep 17 00:00:00 2001 From: Philip Colmer Date: Fri, 22 Dec 2017 15:51:43 +0000 Subject: [PATCH 01/13] Implement (most of the) MFA support for Okta. --- awsprocesscreds/saml.py | 209 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 209 insertions(+) diff --git a/awsprocesscreds/saml.py b/awsprocesscreds/saml.py index 1665f51..36d3f27 100644 --- a/awsprocesscreds/saml.py +++ b/awsprocesscreds/saml.py @@ -1,3 +1,5 @@ +from sys import version_info + import base64 import getpass import logging @@ -216,6 +218,189 @@ def _get_value_of_first_tag(self, root, tag, attr, trait): class OktaAuthenticator(GenericFormsBasedAuthenticator): _AUTH_URL = '/api/v1/authn' + _ERROR_AUTH_CANCELLED = ( + 'Authentication cancelled' + ) + + _ERROR_LOCKED_OUT = ( + "You are locked out of your Okta account. Go to %s to unlock it." + ) + + _ERROR_PASSWORD_EXPIRED = ( + "Your password has expired. Go to %s to change it." + ) + + _ERROR_MFA_ENROLL = ( + "You need to enroll a MFA first." + ) + + _MSG_AUTH_CODE = ( + "Authentication code (RETURN to cancel): " + ) + + _MSG_ANSWER = ( + "Answer (RETURN to cancel): " + ) + + _MSG_SMS_CODE = ( + "Authentication code (RETURN to cancel, " + "'RESEND' to get new code sent): " + ) + + def get_response(self, prompt): + py3 = version_info[0] > 2 + if py3: + response = input(prompt) + else: + response = raw_input(prompt) + if response == "": + raise SAMLError(self._ERROR_AUTH_CANCELLED) + return response + + def get_assertion_from_response(self, endpoint, parsed): + session_token = parsed['sessionToken'] + saml_url = endpoint + '?sessionToken=%s' % session_token + response = self._requests_session.get(saml_url) + logger.info( + 'Received HTTP response of status code: %s', response.status_code) + r = self._extract_saml_assertion_from_response(response.text) + logger.info( + 'Received the following SAML assertion: \n%s', r, + extra={'is_saml_assertion': True} + ) + return r + + def process_mfa_totp(self, endpoint, url, statetoken): + while True: + response = self.get_response(self._MSG_AUTH_CODE) + totp_response = self._requests_session.post( + url, + headers={'Content-Type': 'application/json', + 'Accept': 'application/json'}, + data=json.dumps({'stateToken': statetoken, + 'passCode': response}) + ) + totp_parsed = json.loads(totp_response.text) + if totp_response.status_code == 200: + return self.get_assertion_from_response(endpoint, totp_parsed) + elif totp_response.status_code >= 400: + print totp_parsed["errorCauses"][0]["errorSummary"] + + def process_mfa_push(self, endpoint, url, statetoken): + print "Waiting for result of push notification ..." + while True: + totp_response = self._requests_session.post( + url, + headers={'Content-Type': 'application/json', + 'Accept': 'application/json'}, + data=json.dumps({'stateToken': statetoken}) + ) + totp_parsed = json.loads(totp_response.text) + if totp_parsed["status"] == "SUCCESS": + return self.get_assertion_from_response(endpoint, totp_parsed) + elif totp_parsed["factorResult"] != "WAITING": + raise SAMLError(self._ERROR_AUTH_CANCELLED) + + def process_mfa_security_question(self, endpoint, url, statetoken): + while True: + response = self.get_response(self._MSG_ANSWER) + totp_response = self._requests_session.post( + url, + headers={'Content-Type': 'application/json', + 'Accept': 'application/json'}, + data=json.dumps({'stateToken': statetoken, + 'answer': response}) + ) + totp_parsed = json.loads(totp_response.text) + if totp_response.status_code == 200: + return self.get_assertion_from_response(endpoint, totp_parsed) + elif totp_response.status_code >= 400: + print totp_parsed["errorCauses"][0]["errorSummary"] + + def verify_sms_factor(self, url, statetoken, passcode): + body = {'stateToken': statetoken} + if passcode != "": + body['passCode'] = passcode + return self._requests_session.post( + url, + headers={'Content-Type': 'application/json', + 'Accept': 'application/json'}, + data=json.dumps(body) + ) + + def process_mfa_sms(self, endpoint, url, statetoken): + # Need to trigger the initial code to be sent ... + print "Requesting code to be sent to your phone ..." + self.verify_sms_factor(url, statetoken, "") + while True: + response = self.get_response(self._MSG_SMS_CODE) + if response == "RESEND": + response = "" + sms_response = self.verify_sms_factor(url, statetoken, response) + # If we've just requested a resend, don't check the result + # - just loop around to get the next response from the user. + if response != "": + sms_parsed = json.loads(sms_response.text) + if sms_response.status_code == 200: + return self.get_assertion_from_response(endpoint, + sms_parsed) + elif sms_response.status_code >= 400: + print sms_parsed["errorCauses"][0]["errorSummary"] + + def display_mfa_choices(self, parsed): + index = 1 + for f in parsed["_embedded"]["factors"]: + if f["factorType"] == "token": + print "%s: %s token" % (index, f["provider"]) + elif f["factorType"] == "token:software:totp": + print "%s: %s authenticator app" % (index, f["provider"]) + elif f["factorType"] == "sms": + print "%s: SMS text message" % index + elif f["factorType"] == "push": + print "%s: Push notification" % index + elif f["factorType"] == "question": + print "%s: Security question" % index + else: + print "%s: %s %s" % (index, f["provider"], f["factorType"]) + index += 1 + return index + + def get_mfa_choice(self, parsed): + while True: + print "Please choose from the following authentication choices:" + count = self.display_mfa_choices(parsed) + print ("Enter the number corresponding to your choice " + "or press RETURN") + response = self.get_response("to cancel authentication: ") + choice = 0 + try: + choice = int(response) + except ValueError: + pass + if choice > 0 and choice < count: + return choice + + def process_mfa_verification(self, endpoint, parsed): + # If we've only got one factor, pick that automatically + if len(parsed["_embedded"]["factors"]) == 1: + choice = 1 + else: + choice = self.get_mfa_choice(parsed) + factor = parsed["_embedded"]["factors"][choice - 1] + url = factor["_links"]["verify"]["href"] + statetoken = parsed["stateToken"] + if factor["factorType"] == "token:software:totp": + return self.process_mfa_totp(endpoint, url, statetoken) + elif factor["factorType"] == "push": + return self.process_mfa_push(endpoint, url, statetoken) + elif factor["factorType"] == "question": + return self.process_mfa_security_question(endpoint, + url, statetoken) + elif factor["factorType"] == "sms": + return self.process_mfa_sms(endpoint, url, statetoken) + else: + raise SAMLError("Unsupported factor") + def retrieve_saml_assertion(self, config): self._validate_config_values(config) endpoint = config['saml_endpoint'] @@ -235,6 +420,30 @@ def retrieve_saml_assertion(self, config): 'password': password}) ) parsed = json.loads(response.text) + logger.info( + 'Got status %s and response: %s', + response.status_code, response.text + ) + if response.status_code == 401: + raise SAMLError(self._ERROR_LOGIN_FAILED_NON_200 % + parsed["errorSummary"]) + if "status" in parsed: + if parsed["status"] == "SUCCESS": + return self.get_assertion_from_response(endpoint, parsed) + elif parsed["status"] == "LOCKED_OUT": + raise SAMLError(self._ERROR_LOCKED_OUT % + parsed["_links"]["href"]) + elif parsed["status"] == "PASSWORD_EXPIRED": + raise SAMLError(self._ERROR_PASSWORD_EXPIRED % + parsed["_links"]["href"]) + elif parsed["status"] == "MFA_ENROLL": + raise SAMLError(self._ERROR_MFA_ENROLL) + elif parsed["status"] == "MFA_REQUIRED": + return self.process_mfa_verification(endpoint, parsed) + # If we get to here, the chances are we're running the functional + # tests and NOT running against Okta's service so keep the + # original code to keep the tests happy as the tests don't use + # valid Okta responses ... session_token = parsed['sessionToken'] saml_url = endpoint + '?sessionToken=%s' % session_token response = self._requests_session.get(saml_url) From bc4dfcc502f25fa6618101524615d0f9ce9b0750 Mon Sep 17 00:00:00 2001 From: Philip Colmer Date: Fri, 22 Dec 2017 15:59:27 +0000 Subject: [PATCH 02/13] Python 3 fixes. --- awsprocesscreds/saml.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/awsprocesscreds/saml.py b/awsprocesscreds/saml.py index 36d3f27..23c9215 100644 --- a/awsprocesscreds/saml.py +++ b/awsprocesscreds/saml.py @@ -284,10 +284,10 @@ def process_mfa_totp(self, endpoint, url, statetoken): if totp_response.status_code == 200: return self.get_assertion_from_response(endpoint, totp_parsed) elif totp_response.status_code >= 400: - print totp_parsed["errorCauses"][0]["errorSummary"] + print(totp_parsed["errorCauses"][0]["errorSummary"]) def process_mfa_push(self, endpoint, url, statetoken): - print "Waiting for result of push notification ..." + print("Waiting for result of push notification ...") while True: totp_response = self._requests_session.post( url, @@ -315,7 +315,7 @@ def process_mfa_security_question(self, endpoint, url, statetoken): if totp_response.status_code == 200: return self.get_assertion_from_response(endpoint, totp_parsed) elif totp_response.status_code >= 400: - print totp_parsed["errorCauses"][0]["errorSummary"] + print(totp_parsed["errorCauses"][0]["errorSummary"]) def verify_sms_factor(self, url, statetoken, passcode): body = {'stateToken': statetoken} @@ -330,7 +330,7 @@ def verify_sms_factor(self, url, statetoken, passcode): def process_mfa_sms(self, endpoint, url, statetoken): # Need to trigger the initial code to be sent ... - print "Requesting code to be sent to your phone ..." + print("Requesting code to be sent to your phone ...") self.verify_sms_factor(url, statetoken, "") while True: response = self.get_response(self._MSG_SMS_CODE) @@ -345,32 +345,32 @@ def process_mfa_sms(self, endpoint, url, statetoken): return self.get_assertion_from_response(endpoint, sms_parsed) elif sms_response.status_code >= 400: - print sms_parsed["errorCauses"][0]["errorSummary"] + print(sms_parsed["errorCauses"][0]["errorSummary"]) def display_mfa_choices(self, parsed): index = 1 for f in parsed["_embedded"]["factors"]: if f["factorType"] == "token": - print "%s: %s token" % (index, f["provider"]) + print("%s: %s token" % (index, f["provider"])) elif f["factorType"] == "token:software:totp": - print "%s: %s authenticator app" % (index, f["provider"]) + print("%s: %s authenticator app" % (index, f["provider"])) elif f["factorType"] == "sms": - print "%s: SMS text message" % index + print("%s: SMS text message" % index) elif f["factorType"] == "push": - print "%s: Push notification" % index + print("%s: Push notification" % index) elif f["factorType"] == "question": - print "%s: Security question" % index + print("%s: Security question" % index) else: - print "%s: %s %s" % (index, f["provider"], f["factorType"]) + print("%s: %s %s" % (index, f["provider"], f["factorType"])) index += 1 return index def get_mfa_choice(self, parsed): while True: - print "Please choose from the following authentication choices:" + print("Please choose from the following authentication choices:") count = self.display_mfa_choices(parsed) - print ("Enter the number corresponding to your choice " - "or press RETURN") + print("Enter the number corresponding to your choice " + "or press RETURN") response = self.get_response("to cancel authentication: ") choice = 0 try: From 7f2bd59a735ccdd5f0b1fdebcf36d70ef825c595 Mon Sep 17 00:00:00 2001 From: Philip Colmer Date: Fri, 22 Dec 2017 16:08:44 +0000 Subject: [PATCH 03/13] Python 2 & 3 compatibility fixes. --- awsprocesscreds/saml.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/awsprocesscreds/saml.py b/awsprocesscreds/saml.py index 23c9215..4cf79ad 100644 --- a/awsprocesscreds/saml.py +++ b/awsprocesscreds/saml.py @@ -1,5 +1,4 @@ -from sys import version_info - +from __future__ import print_function import base64 import getpass import logging @@ -18,6 +17,7 @@ import botocore.session from .compat import escape +from six.moves import input class SAMLError(Exception): @@ -248,11 +248,7 @@ class OktaAuthenticator(GenericFormsBasedAuthenticator): ) def get_response(self, prompt): - py3 = version_info[0] > 2 - if py3: - response = input(prompt) - else: - response = raw_input(prompt) + response = input(prompt) if response == "": raise SAMLError(self._ERROR_AUTH_CANCELLED) return response From 556dad93fbb7a191e3c71ce04cf0bde53287ed23 Mon Sep 17 00:00:00 2001 From: Philip Colmer Date: Fri, 22 Dec 2017 16:11:14 +0000 Subject: [PATCH 04/13] Should have run prcheck before last commit! --- awsprocesscreds/saml.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/awsprocesscreds/saml.py b/awsprocesscreds/saml.py index 4cf79ad..8cba9eb 100644 --- a/awsprocesscreds/saml.py +++ b/awsprocesscreds/saml.py @@ -7,6 +7,7 @@ from copy import deepcopy import six +from six.moves import input import requests import botocore from botocore.client import Config @@ -17,7 +18,6 @@ import botocore.session from .compat import escape -from six.moves import input class SAMLError(Exception): From b428332f25ce891b69cecff6a711a9e39f6d771e Mon Sep 17 00:00:00 2001 From: Philip Colmer Date: Wed, 17 Jan 2018 11:55:05 +0000 Subject: [PATCH 05/13] Get getpass instead of print. --- awsprocesscreds/saml.py | 45 ++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/awsprocesscreds/saml.py b/awsprocesscreds/saml.py index 8cba9eb..cce057d 100644 --- a/awsprocesscreds/saml.py +++ b/awsprocesscreds/saml.py @@ -280,10 +280,13 @@ def process_mfa_totp(self, endpoint, url, statetoken): if totp_response.status_code == 200: return self.get_assertion_from_response(endpoint, totp_parsed) elif totp_response.status_code >= 400: - print(totp_parsed["errorCauses"][0]["errorSummary"]) + error = totp_parsed["errorCauses"][0]["errorSummary"] + getpass.getpass("%s\r\nPress RETURN to continue\r\n" + % error) def process_mfa_push(self, endpoint, url, statetoken): - print("Waiting for result of push notification ...") + getpass.getpass(("Waiting for result of push notification ..." + "press RETURN to continue")) while True: totp_response = self._requests_session.post( url, @@ -311,7 +314,9 @@ def process_mfa_security_question(self, endpoint, url, statetoken): if totp_response.status_code == 200: return self.get_assertion_from_response(endpoint, totp_parsed) elif totp_response.status_code >= 400: - print(totp_parsed["errorCauses"][0]["errorSummary"]) + error = totp_parsed["errorCauses"][0]["errorSummary"] + getpass.getpass("%s\r\nPress RETURN to continue\r\n" + % error) def verify_sms_factor(self, url, statetoken, passcode): body = {'stateToken': statetoken} @@ -326,7 +331,8 @@ def verify_sms_factor(self, url, statetoken, passcode): def process_mfa_sms(self, endpoint, url, statetoken): # Need to trigger the initial code to be sent ... - print("Requesting code to be sent to your phone ...") + getpass.getpass(("Requesting code to be sent to your phone ..." + " press RETURN to continue")) self.verify_sms_factor(url, statetoken, "") while True: response = self.get_response(self._MSG_SMS_CODE) @@ -341,32 +347,39 @@ def process_mfa_sms(self, endpoint, url, statetoken): return self.get_assertion_from_response(endpoint, sms_parsed) elif sms_response.status_code >= 400: - print(sms_parsed["errorCauses"][0]["errorSummary"]) + error = sms_parsed["errorCauses"][0]["errorSummary"] + getpass.getpass("%s\r\nPress RETURN to continue\r\n" + % error) def display_mfa_choices(self, parsed): index = 1 + prompt = "" for f in parsed["_embedded"]["factors"]: if f["factorType"] == "token": - print("%s: %s token" % (index, f["provider"])) + prompt += "%s: %s token\r\n" % (index, f["provider"]) elif f["factorType"] == "token:software:totp": - print("%s: %s authenticator app" % (index, f["provider"])) + prompt += ("%s: %s authenticator app\r\n" + % (index, f["provider"])) elif f["factorType"] == "sms": - print("%s: SMS text message" % index) + prompt += "%s: SMS text message\r\n" % index elif f["factorType"] == "push": - print("%s: Push notification" % index) + prompt += "%s: Push notification\r\n" % index elif f["factorType"] == "question": - print("%s: Security question" % index) + prompt += "%s: Security question\r\n" % index else: - print("%s: %s %s" % (index, f["provider"], f["factorType"])) + prompt += "%s: %s %s\r\n" % (index, + f["provider"], + f["factorType"]) index += 1 - return index + return index, prompt def get_mfa_choice(self, parsed): while True: - print("Please choose from the following authentication choices:") - count = self.display_mfa_choices(parsed) - print("Enter the number corresponding to your choice " - "or press RETURN") + count, prompt = self.display_mfa_choices(parsed) + prompt = ("Please choose from the following authentication" + "choices:\r\n") + prompt + prompt += ("Enter the number corresponding to your choice " + "or press RETURN\r\n") response = self.get_response("to cancel authentication: ") choice = 0 try: From 65539962e7cedbfebf9a6d032df792ab8d31bb2e Mon Sep 17 00:00:00 2001 From: Philip Colmer Date: Wed, 17 Jan 2018 12:00:34 +0000 Subject: [PATCH 06/13] Fix prompting oversight. --- awsprocesscreds/saml.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/awsprocesscreds/saml.py b/awsprocesscreds/saml.py index cce057d..32056ea 100644 --- a/awsprocesscreds/saml.py +++ b/awsprocesscreds/saml.py @@ -379,8 +379,8 @@ def get_mfa_choice(self, parsed): prompt = ("Please choose from the following authentication" "choices:\r\n") + prompt prompt += ("Enter the number corresponding to your choice " - "or press RETURN\r\n") - response = self.get_response("to cancel authentication: ") + "or press RETURN to cancel authentication: ") + response = self.get_response(prompt) choice = 0 try: choice = int(response) From 064aa03a4b9f00a661c5684f33ecb62b930ec0ab Mon Sep 17 00:00:00 2001 From: Philip Colmer Date: Wed, 17 Jan 2018 12:06:08 +0000 Subject: [PATCH 07/13] Change to use password prompter. --- awsprocesscreds/saml.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/awsprocesscreds/saml.py b/awsprocesscreds/saml.py index 32056ea..7b91043 100644 --- a/awsprocesscreds/saml.py +++ b/awsprocesscreds/saml.py @@ -281,12 +281,12 @@ def process_mfa_totp(self, endpoint, url, statetoken): return self.get_assertion_from_response(endpoint, totp_parsed) elif totp_response.status_code >= 400: error = totp_parsed["errorCauses"][0]["errorSummary"] - getpass.getpass("%s\r\nPress RETURN to continue\r\n" - % error) + _password_prompter("%s\r\nPress RETURN to continue\r\n" + % error) def process_mfa_push(self, endpoint, url, statetoken): - getpass.getpass(("Waiting for result of push notification ..." - "press RETURN to continue")) + _password_prompter(("Waiting for result of push notification ..." + "press RETURN to continue")) while True: totp_response = self._requests_session.post( url, @@ -315,8 +315,8 @@ def process_mfa_security_question(self, endpoint, url, statetoken): return self.get_assertion_from_response(endpoint, totp_parsed) elif totp_response.status_code >= 400: error = totp_parsed["errorCauses"][0]["errorSummary"] - getpass.getpass("%s\r\nPress RETURN to continue\r\n" - % error) + _password_prompter("%s\r\nPress RETURN to continue\r\n" + % error) def verify_sms_factor(self, url, statetoken, passcode): body = {'stateToken': statetoken} @@ -331,8 +331,8 @@ def verify_sms_factor(self, url, statetoken, passcode): def process_mfa_sms(self, endpoint, url, statetoken): # Need to trigger the initial code to be sent ... - getpass.getpass(("Requesting code to be sent to your phone ..." - " press RETURN to continue")) + _password_prompter(("Requesting code to be sent to your phone ..." + " press RETURN to continue")) self.verify_sms_factor(url, statetoken, "") while True: response = self.get_response(self._MSG_SMS_CODE) @@ -348,8 +348,8 @@ def process_mfa_sms(self, endpoint, url, statetoken): sms_parsed) elif sms_response.status_code >= 400: error = sms_parsed["errorCauses"][0]["errorSummary"] - getpass.getpass("%s\r\nPress RETURN to continue\r\n" - % error) + _password_prompter("%s\r\nPress RETURN to continue\r\n" + % error) def display_mfa_choices(self, parsed): index = 1 @@ -380,7 +380,7 @@ def get_mfa_choice(self, parsed): "choices:\r\n") + prompt prompt += ("Enter the number corresponding to your choice " "or press RETURN to cancel authentication: ") - response = self.get_response(prompt) + response = self._password_prompter(prompt) choice = 0 try: choice = int(response) From aac125f92cf2f46db1db6ab4546cea20976a323b Mon Sep 17 00:00:00 2001 From: Philip Colmer Date: Wed, 17 Jan 2018 12:09:06 +0000 Subject: [PATCH 08/13] Fix missing "self." --- awsprocesscreds/saml.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/awsprocesscreds/saml.py b/awsprocesscreds/saml.py index 7b91043..94b59f0 100644 --- a/awsprocesscreds/saml.py +++ b/awsprocesscreds/saml.py @@ -281,12 +281,12 @@ def process_mfa_totp(self, endpoint, url, statetoken): return self.get_assertion_from_response(endpoint, totp_parsed) elif totp_response.status_code >= 400: error = totp_parsed["errorCauses"][0]["errorSummary"] - _password_prompter("%s\r\nPress RETURN to continue\r\n" - % error) + self._password_prompter("%s\r\nPress RETURN to continue\r\n" + % error) def process_mfa_push(self, endpoint, url, statetoken): - _password_prompter(("Waiting for result of push notification ..." - "press RETURN to continue")) + self._password_prompter(("Waiting for result of push notification ..." + "press RETURN to continue")) while True: totp_response = self._requests_session.post( url, @@ -315,8 +315,8 @@ def process_mfa_security_question(self, endpoint, url, statetoken): return self.get_assertion_from_response(endpoint, totp_parsed) elif totp_response.status_code >= 400: error = totp_parsed["errorCauses"][0]["errorSummary"] - _password_prompter("%s\r\nPress RETURN to continue\r\n" - % error) + self._password_prompter("%s\r\nPress RETURN to continue\r\n" + % error) def verify_sms_factor(self, url, statetoken, passcode): body = {'stateToken': statetoken} @@ -331,8 +331,8 @@ def verify_sms_factor(self, url, statetoken, passcode): def process_mfa_sms(self, endpoint, url, statetoken): # Need to trigger the initial code to be sent ... - _password_prompter(("Requesting code to be sent to your phone ..." - " press RETURN to continue")) + self._password_prompter(("Requesting code to be sent to your phone ..." + " press RETURN to continue")) self.verify_sms_factor(url, statetoken, "") while True: response = self.get_response(self._MSG_SMS_CODE) @@ -348,8 +348,9 @@ def process_mfa_sms(self, endpoint, url, statetoken): sms_parsed) elif sms_response.status_code >= 400: error = sms_parsed["errorCauses"][0]["errorSummary"] - _password_prompter("%s\r\nPress RETURN to continue\r\n" - % error) + self._password_prompter(("%s\r\n" + "Press RETURN to continue\r\n") + % error) def display_mfa_choices(self, parsed): index = 1 From 81e19debbf7333e0e62d3f21c922211d6e3aff5b Mon Sep 17 00:00:00 2001 From: Philip Colmer Date: Wed, 17 Jan 2018 12:10:22 +0000 Subject: [PATCH 09/13] Add a missing space. --- awsprocesscreds/saml.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/awsprocesscreds/saml.py b/awsprocesscreds/saml.py index 94b59f0..777c51d 100644 --- a/awsprocesscreds/saml.py +++ b/awsprocesscreds/saml.py @@ -378,7 +378,7 @@ def get_mfa_choice(self, parsed): while True: count, prompt = self.display_mfa_choices(parsed) prompt = ("Please choose from the following authentication" - "choices:\r\n") + prompt + " choices:\r\n") + prompt prompt += ("Enter the number corresponding to your choice " "or press RETURN to cancel authentication: ") response = self._password_prompter(prompt) From 583dce67cdb3ad64fddb2ed7a418f656e560e08b Mon Sep 17 00:00:00 2001 From: Philip Colmer Date: Thu, 3 May 2018 08:46:01 +0100 Subject: [PATCH 10/13] Improve testing. --- awsprocesscreds/saml.py | 25 +++------ tests/functional/test_saml.py | 102 +++++++++++++++++++++++++++++++--- tests/unit/test_saml.py | 4 +- 3 files changed, 106 insertions(+), 25 deletions(-) diff --git a/awsprocesscreds/saml.py b/awsprocesscreds/saml.py index 777c51d..99a0245 100644 --- a/awsprocesscreds/saml.py +++ b/awsprocesscreds/saml.py @@ -1,4 +1,5 @@ from __future__ import print_function +import sys import base64 import getpass import logging @@ -247,8 +248,13 @@ class OktaAuthenticator(GenericFormsBasedAuthenticator): "'RESEND' to get new code sent): " ) + def __obtain_input(self, text): + if sys.version_info >= (3, 0): + return input(text) + return raw_input(text) # noqa + def get_response(self, prompt): - response = input(prompt) + response = self.__obtain_input(prompt) if response == "": raise SAMLError(self._ERROR_AUTH_CANCELLED) return response @@ -450,21 +456,7 @@ def retrieve_saml_assertion(self, config): raise SAMLError(self._ERROR_MFA_ENROLL) elif parsed["status"] == "MFA_REQUIRED": return self.process_mfa_verification(endpoint, parsed) - # If we get to here, the chances are we're running the functional - # tests and NOT running against Okta's service so keep the - # original code to keep the tests happy as the tests don't use - # valid Okta responses ... - session_token = parsed['sessionToken'] - saml_url = endpoint + '?sessionToken=%s' % session_token - response = self._requests_session.get(saml_url) - logger.info( - 'Received HTTP response of status code: %s', response.status_code) - r = self._extract_saml_assertion_from_response(response.text) - logger.info( - 'Received the following SAML assertion: \n%s', r, - extra={'is_saml_assertion': True} - ) - return r + raise SAMLError("Code logic failure") def is_suitable(self, config): return (config.get('saml_authentication_type') == 'form' and @@ -526,7 +518,6 @@ class SAMLCredentialFetcher(CachedCredentialFetcher): SAML_FORM_AUTHENTICATORS = { 'okta': OktaAuthenticator, 'adfs': ADFSFormsBasedAuthenticator - } def __init__(self, client_creator, provider_name, saml_config, diff --git a/tests/functional/test_saml.py b/tests/functional/test_saml.py index dcaf836..f24acf9 100644 --- a/tests/functional/test_saml.py +++ b/tests/functional/test_saml.py @@ -1,3 +1,4 @@ +import sys import mock import json import logging @@ -9,7 +10,8 @@ from tests import create_assertion from awsprocesscreds.cli import saml, PrettyPrinterLogHandler -from awsprocesscreds.saml import SAMLCredentialFetcher +from awsprocesscreds.saml import SAMLCredentialFetcher, OktaAuthenticator, \ + SAMLError @pytest.fixture @@ -22,9 +24,95 @@ def argv(): ] +def test_get_response(): + authenticator = OktaAuthenticator(None) + if sys.version_info >= (3, 0): + ips = "builtins.input" + else: + ips = "__builtin__.raw_input" + + with mock.patch(ips, return_value=""): + with pytest.raises(SAMLError): + authenticator.get_response("") + with mock.patch(ips, return_value="fake input"): + response = authenticator.get_response("") + assert response == "fake input" + + +def test_get_assertion_from_response(mock_requests_session, assertion): + authenticator = OktaAuthenticator(None) + session_token = { + 'sessionToken': '1234', + } + assertion_form = '
' + assertion_form = assertion_form % assertion.decode('ascii') + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + mock_requests_session.get.return_value = assertion_response + result = authenticator.get_assertion_from_response( + "endpoint", session_token) + assert result == assertion + + +# def test_process_mfa_totp( +# mock_requests_session, prompter, assertion, capsys): +# authenticator = SAMLCredentialFetcher( +# client_creator=None, +# provider_name="okta", +# saml_config=None, +# password_prompter=prompter) +# if sys.version_info >= (3, 0): +# ips = "builtins.input" +# else: +# ips = "__builtin__.raw_input" +# session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} +# token_response = mock.Mock( +# spec=requests.Response, +# status_code=200, +# text=json.dumps(session_token) +# ) +# assertion_form = '
' +# assertion_form = assertion_form % assertion.decode('ascii') +# assertion_response = mock.Mock( +# spec=requests.Response, status_code=200, text=assertion_form +# ) + +# mock_requests_session.post.return_value = token_response +# mock_requests_session.get.return_value = assertion_response + +# with mock.patch("getpass.getpass", return_value="12345678"): +# with mock.patch(ips, return_value="12345678"): +# result = authenticator._authenticator.process_mfa_totp( +# "endpoint", "url", "statetoken") +# assert result == assertion + +# # Now test the handling of a 400 error code +# error_response = { +# "errorCauses": [ +# { +# "errorSummary": "errorSummary" +# } +# ] +# } +# token_response = mock.Mock( +# spec=requests.Response, +# status_code=400, +# text=json.dumps(error_response) +# ) +# mock_requests_session.post.return_value = token_response +# with mock.patch("getpass.getpass", return_value="12345678"): +# with mock.patch(ips, return_value="12345678"): +# result = authenticator._authenticator.process_mfa_totp( +# "endpoint", "url", "statetoken") +# stdout, _ = capsys.readouterr() +# assert stdout.endswith('\n') +# assert result == assertion + + def test_cli(mock_requests_session, argv, prompter, assertion, client_creator, capsys, cache_dir): - session_token = {'sessionToken': 'spam'} + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} token_response = mock.Mock( spec=requests.Response, status_code=200, text=json.dumps(session_token) ) @@ -55,7 +143,7 @@ def test_cli(mock_requests_session, argv, prompter, assertion, client_creator, def test_no_cache(mock_requests_session, argv, prompter, assertion, client_creator, capsys, cache_dir): - session_token = {'sessionToken': 'spam'} + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} token_response = mock.Mock( spec=requests.Response, status_code=200, text=json.dumps(session_token) ) @@ -92,7 +180,7 @@ def test_no_cache(mock_requests_session, argv, prompter, assertion, def test_verbose(mock_requests_session, argv, prompter, assertion, client_creator, cache_dir): - session_token = {'sessionToken': 'spam'} + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} token_response = mock.Mock( spec=requests.Response, status_code=200, text=json.dumps(session_token) ) @@ -123,7 +211,7 @@ def test_verbose(mock_requests_session, argv, prompter, assertion, def test_log_handler_parses_assertion(mock_requests_session, argv, prompter, client_creator, cache_dir, caplog): - session_token = {'sessionToken': 'spam'} + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} token_response = mock.Mock( spec=requests.Response, status_code=200, text=json.dumps(session_token) ) @@ -160,7 +248,7 @@ def test_log_handler_parses_assertion(mock_requests_session, argv, prompter, def test_log_handler_parses_dict(mock_requests_session, argv, prompter, client_creator, cache_dir, caplog): - session_token = {'sessionToken': 'spam'} + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} token_response = mock.Mock( spec=requests.Response, status_code=200, text=json.dumps(session_token) ) @@ -237,7 +325,7 @@ def test_unsupported_saml_provider(client_creator, prompter): def test_prompter_only_called_once(client_creator, prompter, assertion, mock_requests_session): - session_token = {'sessionToken': 'spam'} + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} token_response = mock.Mock( spec=requests.Response, status_code=200, text=json.dumps(session_token) ) diff --git a/tests/unit/test_saml.py b/tests/unit/test_saml.py index db2e218..a87d37f 100644 --- a/tests/unit/test_saml.py +++ b/tests/unit/test_saml.py @@ -373,7 +373,9 @@ def test_authn_requests_made(self, okta_auth, okta_config, session_token = 'mytoken' # 1st response is for authentication. mock_requests_session.post.return_value = mock.Mock( - text=json.dumps({"sessionToken": session_token}), + text=json.dumps( + {"sessionToken": session_token, "status": "SUCCESS"} + ), status_code=200 ) # 2nd response is to then retrieve the assertion. From f0078a5744f64bbf3f9a6aff70975ea20619721f Mon Sep 17 00:00:00 2001 From: Philip Colmer Date: Fri, 16 Nov 2018 11:42:46 +0000 Subject: [PATCH 11/13] Add missing tests --- .gitignore | 1 + awsprocesscreds/saml.py | 9 +- tests/functional/test_saml.py | 608 ++++++++++++++++++++++++++++++---- 3 files changed, 540 insertions(+), 78 deletions(-) diff --git a/.gitignore b/.gitignore index 5b338fe..d77cc3b 100644 --- a/.gitignore +++ b/.gitignore @@ -64,6 +64,7 @@ venv/ ENV/ env.bak/ venv.bak/ +venv-awsprocesscreds/ # mypy .mypy_cache/ diff --git a/awsprocesscreds/saml.py b/awsprocesscreds/saml.py index 99a0245..4fa5873 100644 --- a/awsprocesscreds/saml.py +++ b/awsprocesscreds/saml.py @@ -1,5 +1,4 @@ from __future__ import print_function -import sys import base64 import getpass import logging @@ -248,13 +247,11 @@ class OktaAuthenticator(GenericFormsBasedAuthenticator): "'RESEND' to get new code sent): " ) - def __obtain_input(self, text): - if sys.version_info >= (3, 0): - return input(text) - return raw_input(text) # noqa + def obtain_input(self, text): + return input(text) def get_response(self, prompt): - response = self.__obtain_input(prompt) + response = self.obtain_input(prompt) if response == "": raise SAMLError(self._ERROR_AUTH_CANCELLED) return response diff --git a/tests/functional/test_saml.py b/tests/functional/test_saml.py index f24acf9..baf26ee 100644 --- a/tests/functional/test_saml.py +++ b/tests/functional/test_saml.py @@ -1,4 +1,3 @@ -import sys import mock import json import logging @@ -24,90 +23,555 @@ def argv(): ] -def test_get_response(): +@mock.patch( + 'awsprocesscreds.saml.OktaAuthenticator.obtain_input', + return_value="", + autospec=True +) +def test_get_response_1(mock_obtain_input): authenticator = OktaAuthenticator(None) - if sys.version_info >= (3, 0): - ips = "builtins.input" - else: - ips = "__builtin__.raw_input" + with pytest.raises(SAMLError): + authenticator.get_response("") - with mock.patch(ips, return_value=""): - with pytest.raises(SAMLError): - authenticator.get_response("") - with mock.patch(ips, return_value="fake input"): - response = authenticator.get_response("") - assert response == "fake input" +@mock.patch( + 'awsprocesscreds.saml.OktaAuthenticator.obtain_input', + return_value="mock_result", + autospec=True +) +def test_get_response_2(mock_obtain_input): + authenticator = OktaAuthenticator(None) + response = authenticator.get_response("") + assert response == "mock_result" + + +def test_process_mfa_totp( + mock_requests_session, prompter, assertion, capsys): + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} + token_response = mock.Mock( + spec=requests.Response, + status_code=200, + text=json.dumps(session_token) + ) + assertion_form = '
' + assertion_form = assertion_form % assertion.decode() + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) -def test_get_assertion_from_response(mock_requests_session, assertion): + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=prompter) + + with mock.patch( + "awsprocesscreds.saml.OktaAuthenticator.obtain_input", + return_value="12345678"): + result = authenticator._authenticator.process_mfa_totp( + "endpoint", "url", "statetoken") + assert result == assertion.decode() + + +def test_process_mfa_push_1( + mock_requests_session, prompter, assertion, capsys): + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} + token_response = mock.Mock( + spec=requests.Response, + status_code=200, + text=json.dumps(session_token) + ) + assertion_form = '
' + assertion_form = assertion_form % assertion.decode() + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=prompter) + + result = authenticator._authenticator.process_mfa_push( + "endpoint", "url", "statetoken") + assert result == assertion.decode() + + +def test_process_mfa_push_2( + mock_requests_session, prompter, assertion, capsys): + session_token = { + 'sessionToken': 'spam', + 'status': 'CANCELLED', + 'factorResult': 'FAILED' + } + token_response = mock.Mock( + spec=requests.Response, + status_code=200, + text=json.dumps(session_token) + ) + mock_requests_session.post.return_value = token_response + + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=prompter) + + with pytest.raises(SAMLError): + authenticator._authenticator.process_mfa_push( + "endpoint", "url", "statetoken") + + +def test_process_mfa_security_question( + mock_requests_session, prompter, assertion, capsys): + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} + token_response = mock.Mock( + spec=requests.Response, + status_code=200, + text=json.dumps(session_token) + ) + assertion_form = '
' + assertion_form = assertion_form % assertion.decode() + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=prompter) + + with mock.patch( + "awsprocesscreds.saml.OktaAuthenticator.get_response", + return_value="security_answer"): + result = authenticator._authenticator.process_mfa_security_question( + "endpoint", "url", "statetoken") + assert result == assertion.decode() + + +def test_verify_sms_factor( + mock_requests_session, prompter, assertion, capsys): + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} + token_response = mock.Mock( + spec=requests.Response, + status_code=200, + text=json.dumps(session_token) + ) + mock_requests_session.post.return_value = token_response + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=prompter) + result = authenticator._authenticator.verify_sms_factor( + "url", "statetoken", "passcode") + assert result.status_code == 200 + test = json.loads(result.text) + assert test["status"] == "SUCCESS" + + +def test_process_mfa_sms( + mock_requests_session, prompter, assertion, capsys): + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} + token_response = mock.Mock( + spec=requests.Response, + status_code=200, + text=json.dumps(session_token) + ) + assertion_form = '
' + assertion_form = assertion_form % assertion.decode() + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=prompter) + with mock.patch( + "awsprocesscreds.saml.OktaAuthenticator.get_response", + return_value="12345678"): + with mock.patch( + "awsprocesscreds.saml.OktaAuthenticator.verify_sms_factor", + return_value=token_response): + result = authenticator._authenticator.process_mfa_sms( + "endpoint", "url", "statetoken") + assert result == assertion.decode() + + +def test_display_mfa_choices( + mock_requests_session, prompter, assertion, capsys): + parsed = { + "_embedded": { + "factors": [ + { + "factorType": "token", + "provider": "OKTA" + }, + { + "factorType": "token:software:totp", + "provider": "OKTA" + }, + { + "factorType": "sms" + }, + { + "factorType": "push" + }, + { + "factorType": "question" + }, + { + "factorType": "blackboard", + "provider": "classroom" + } + ] + } + } + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=prompter) + index, prompt = authenticator._authenticator.display_mfa_choices(parsed) + assert index == 7 + assert prompt == ( + "1: OKTA token\r\n" + "2: OKTA authenticator app\r\n" + "3: SMS text message\r\n" + "4: Push notification\r\n" + "5: Security question\r\n" + "6: classroom blackboard\r\n" + ) + + +def test_get_mfa_choice( + mock_requests_session, prompter, assertion, capsys): + def mock_prompter(prompt): + assert prompt == ( + "Please choose from the following authentication choices:\r\n" + "1: SMS text message\r\n" + "Enter the number corresponding to your choice or press RETURN to " + "cancel authentication: " + ) + return "1" + + parsed = { + "_embedded": { + "factors": [ + { + "factorType": "sms" + } + ] + } + } + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=mock_prompter) + response = authenticator._authenticator.get_mfa_choice(parsed) + assert response == 1 + + +def test_process_mfa_verification_1(): + parsed = { + "_embedded": { + "factors": [ + { + "factorType": "unsupported", + "_links": { + "verify": { + "href": "href" + } + } + }, + { + "factorType": "unsupported" + } + ] + }, + "stateToken": "statetoken" + } + authenticator = OktaAuthenticator(None) + with mock.patch( + "awsprocesscreds.saml.OktaAuthenticator.get_mfa_choice", + return_value=1): + with pytest.raises(SAMLError): + authenticator.process_mfa_verification("endpoint", parsed) + + +def test_process_mfa_verification_2(): + parsed = { + "_embedded": { + "factors": [ + { + "factorType": "token:software:totp", + "_links": { + "verify": { + "href": "href" + } + } + } + ] + }, + "stateToken": "statetoken" + } + authenticator = OktaAuthenticator(None) + with mock.patch( + "awsprocesscreds.saml.OktaAuthenticator.process_mfa_totp", + return_value="mock_call"): + result = authenticator.process_mfa_verification("endpoint", parsed) + assert result == "mock_call" + + +def test_process_mfa_verification_3(): + parsed = { + "_embedded": { + "factors": [ + { + "factorType": "push", + "_links": { + "verify": { + "href": "href" + } + } + } + ] + }, + "stateToken": "statetoken" + } authenticator = OktaAuthenticator(None) + with mock.patch( + "awsprocesscreds.saml.OktaAuthenticator.process_mfa_push", + return_value="mock_call"): + result = authenticator.process_mfa_verification("endpoint", parsed) + assert result == "mock_call" + + +def test_process_mfa_verification_4(): + parsed = { + "_embedded": { + "factors": [ + { + "factorType": "question", + "_links": { + "verify": { + "href": "href" + } + } + } + ] + }, + "stateToken": "statetoken" + } + authenticator = OktaAuthenticator(None) + with mock.patch( + "awsprocesscreds.saml.OktaAuthenticator." + "process_mfa_security_question", + return_value="mock_call"): + result = authenticator.process_mfa_verification("endpoint", parsed) + assert result == "mock_call" + + +def test_process_mfa_verification_5(): + parsed = { + "_embedded": { + "factors": [ + { + "factorType": "sms", + "_links": { + "verify": { + "href": "href" + } + } + } + ] + }, + "stateToken": "statetoken" + } + authenticator = OktaAuthenticator(None) + with mock.patch( + "awsprocesscreds.saml.OktaAuthenticator.process_mfa_sms", + return_value="mock_call"): + result = authenticator.process_mfa_verification("endpoint", parsed) + assert result == "mock_call" + + +def test_retrieve_saml_assertion_1( + mock_requests_session, argv, prompter, assertion, + client_creator, cache_dir): + session_token = { + 'sessionToken': 'spam', + 'status': 'FAILED', + 'errorSummary': 'Testing failure' + } + token_response = mock.Mock( + spec=requests.Response, status_code=401, text=json.dumps(session_token) + ) + assertion_form = '
' + assertion_form = assertion_form % assertion.decode('ascii') + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + with pytest.raises(SAMLError): + saml(argv=argv, prompter=prompter, client_creator=client_creator, + cache_dir=cache_dir) + + +def test_retrieve_saml_assertion_2( + mock_requests_session, argv, prompter, assertion, + client_creator, cache_dir): session_token = { - 'sessionToken': '1234', + 'sessionToken': 'spam', + 'status': 'LOCKED_OUT', + '_links': { + 'href': 'href' + } } + token_response = mock.Mock( + spec=requests.Response, status_code=200, text=json.dumps(session_token) + ) assertion_form = '
' assertion_form = assertion_form % assertion.decode('ascii') assertion_response = mock.Mock( spec=requests.Response, status_code=200, text=assertion_form ) + + mock_requests_session.post.return_value = token_response mock_requests_session.get.return_value = assertion_response - result = authenticator.get_assertion_from_response( - "endpoint", session_token) - assert result == assertion - - -# def test_process_mfa_totp( -# mock_requests_session, prompter, assertion, capsys): -# authenticator = SAMLCredentialFetcher( -# client_creator=None, -# provider_name="okta", -# saml_config=None, -# password_prompter=prompter) -# if sys.version_info >= (3, 0): -# ips = "builtins.input" -# else: -# ips = "__builtin__.raw_input" -# session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} -# token_response = mock.Mock( -# spec=requests.Response, -# status_code=200, -# text=json.dumps(session_token) -# ) -# assertion_form = '
' -# assertion_form = assertion_form % assertion.decode('ascii') -# assertion_response = mock.Mock( -# spec=requests.Response, status_code=200, text=assertion_form -# ) - -# mock_requests_session.post.return_value = token_response -# mock_requests_session.get.return_value = assertion_response - -# with mock.patch("getpass.getpass", return_value="12345678"): -# with mock.patch(ips, return_value="12345678"): -# result = authenticator._authenticator.process_mfa_totp( -# "endpoint", "url", "statetoken") -# assert result == assertion - -# # Now test the handling of a 400 error code -# error_response = { -# "errorCauses": [ -# { -# "errorSummary": "errorSummary" -# } -# ] -# } -# token_response = mock.Mock( -# spec=requests.Response, -# status_code=400, -# text=json.dumps(error_response) -# ) -# mock_requests_session.post.return_value = token_response -# with mock.patch("getpass.getpass", return_value="12345678"): -# with mock.patch(ips, return_value="12345678"): -# result = authenticator._authenticator.process_mfa_totp( -# "endpoint", "url", "statetoken") -# stdout, _ = capsys.readouterr() -# assert stdout.endswith('\n') -# assert result == assertion + with pytest.raises(SAMLError): + saml(argv=argv, prompter=prompter, client_creator=client_creator, + cache_dir=cache_dir) + + +def test_retrieve_saml_assertion_3( + mock_requests_session, argv, prompter, assertion, + client_creator, cache_dir): + session_token = { + 'sessionToken': 'spam', + 'status': 'PASSWORD_EXPIRED', + '_links': { + 'href': 'href' + } + } + token_response = mock.Mock( + spec=requests.Response, status_code=200, text=json.dumps(session_token) + ) + assertion_form = '
' + assertion_form = assertion_form % assertion.decode('ascii') + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + with pytest.raises(SAMLError): + saml(argv=argv, prompter=prompter, client_creator=client_creator, + cache_dir=cache_dir) + + +def test_retrieve_saml_assertion_4( + mock_requests_session, argv, prompter, assertion, + client_creator, cache_dir): + session_token = { + 'sessionToken': 'spam', + 'status': 'MFA_ENROLL' + } + token_response = mock.Mock( + spec=requests.Response, status_code=200, text=json.dumps(session_token) + ) + assertion_form = '
' + assertion_form = assertion_form % assertion.decode('ascii') + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + with pytest.raises(SAMLError): + saml(argv=argv, prompter=prompter, client_creator=client_creator, + cache_dir=cache_dir) + + +def test_retrieve_saml_assertion_5( + mock_requests_session, argv, prompter, assertion, + client_creator, capsys, cache_dir): + session_token = { + 'sessionToken': 'spam', + 'status': 'MFA_REQUIRED' + } + token_response = mock.Mock( + spec=requests.Response, status_code=200, text=json.dumps(session_token) + ) + assertion_form = '
' + assertion_form = assertion_form % assertion.decode('ascii') + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + + with mock.patch( + "awsprocesscreds.saml.OktaAuthenticator.process_mfa_verification", + return_value=assertion): + saml(argv=argv, prompter=prompter, + client_creator=client_creator, + cache_dir=cache_dir) + + stdout, _ = capsys.readouterr() + assert stdout.endswith('\n') + + response = json.loads(stdout) + expected_response = { + "AccessKeyId": "foo", + "SecretAccessKey": "bar", + "SessionToken": "baz", + "Expiration": mock.ANY, + "Version": 1 + } + assert response == expected_response + + +def test_retrieve_saml_assertion_6( + mock_requests_session, argv, prompter, assertion, + client_creator, cache_dir): + session_token = { + 'sessionToken': 'spam' + } + token_response = mock.Mock( + spec=requests.Response, status_code=200, text=json.dumps(session_token) + ) + assertion_form = '
' + assertion_form = assertion_form % assertion.decode('ascii') + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + + mock_requests_session.post.return_value = token_response + mock_requests_session.get.return_value = assertion_response + with pytest.raises(SAMLError): + saml(argv=argv, prompter=prompter, client_creator=client_creator, + cache_dir=cache_dir) def test_cli(mock_requests_session, argv, prompter, assertion, client_creator, From d8ab8cef30cdfa1f7b9fc27f15ad1e7c4855c1e7 Mon Sep 17 00:00:00 2001 From: Philip Colmer Date: Tue, 20 Nov 2018 14:42:38 +0000 Subject: [PATCH 12/13] Improve input handling and tests --- awsprocesscreds/saml.py | 88 +++++++++---------- tests/functional/test_saml.py | 160 ++++++++++++++++++++++++++-------- 2 files changed, 165 insertions(+), 83 deletions(-) diff --git a/awsprocesscreds/saml.py b/awsprocesscreds/saml.py index 4fa5873..1b36b1c 100644 --- a/awsprocesscreds/saml.py +++ b/awsprocesscreds/saml.py @@ -7,7 +7,6 @@ from copy import deepcopy import six -from six.moves import input import requests import botocore from botocore.client import Config @@ -243,16 +242,13 @@ class OktaAuthenticator(GenericFormsBasedAuthenticator): ) _MSG_SMS_CODE = ( - "Authentication code (RETURN to cancel, " + "SMS authentication code (RETURN to cancel, " "'RESEND' to get new code sent): " ) - def obtain_input(self, text): - return input(text) - - def get_response(self, prompt): - response = self.obtain_input(prompt) - if response == "": + def get_response(self, prompt, allow_cancel=True): + response = self._password_prompter(prompt) + if allow_cancel and response == "": raise SAMLError(self._ERROR_AUTH_CANCELLED) return response @@ -269,6 +265,16 @@ def get_assertion_from_response(self, endpoint, parsed): ) return r + def process_response(self, response, endpoint): + parsed = json.loads(response.text) + if response.status_code == 200: + return self.get_assertion_from_response(endpoint, parsed) + elif response.status_code >= 400: + error = parsed["errorCauses"][0]["errorSummary"] + self.get_response("%s\r\nPress RETURN to continue\r\n" + % error, False) + return None + def process_mfa_totp(self, endpoint, url, statetoken): while True: response = self.get_response(self._MSG_AUTH_CODE) @@ -279,17 +285,13 @@ def process_mfa_totp(self, endpoint, url, statetoken): data=json.dumps({'stateToken': statetoken, 'passCode': response}) ) - totp_parsed = json.loads(totp_response.text) - if totp_response.status_code == 200: - return self.get_assertion_from_response(endpoint, totp_parsed) - elif totp_response.status_code >= 400: - error = totp_parsed["errorCauses"][0]["errorSummary"] - self._password_prompter("%s\r\nPress RETURN to continue\r\n" - % error) + result = self.process_response(totp_response, endpoint) + if result is not None: + return result def process_mfa_push(self, endpoint, url, statetoken): - self._password_prompter(("Waiting for result of push notification ..." - "press RETURN to continue")) + self.get_response(("Press RETURN when you are ready to request the " + "push notification"), False) while True: totp_response = self._requests_session.post( url, @@ -313,13 +315,9 @@ def process_mfa_security_question(self, endpoint, url, statetoken): data=json.dumps({'stateToken': statetoken, 'answer': response}) ) - totp_parsed = json.loads(totp_response.text) - if totp_response.status_code == 200: - return self.get_assertion_from_response(endpoint, totp_parsed) - elif totp_response.status_code >= 400: - error = totp_parsed["errorCauses"][0]["errorSummary"] - self._password_prompter("%s\r\nPress RETURN to continue\r\n" - % error) + result = self.process_response(totp_response, endpoint) + if result is not None: + return result def verify_sms_factor(self, url, statetoken, passcode): body = {'stateToken': statetoken} @@ -334,26 +332,20 @@ def verify_sms_factor(self, url, statetoken, passcode): def process_mfa_sms(self, endpoint, url, statetoken): # Need to trigger the initial code to be sent ... - self._password_prompter(("Requesting code to be sent to your phone ..." - " press RETURN to continue")) self.verify_sms_factor(url, statetoken, "") while True: response = self.get_response(self._MSG_SMS_CODE) + # If the user has asked for the code to be resent, clear + # the response to retrigger sending the code. if response == "RESEND": response = "" sms_response = self.verify_sms_factor(url, statetoken, response) # If we've just requested a resend, don't check the result # - just loop around to get the next response from the user. if response != "": - sms_parsed = json.loads(sms_response.text) - if sms_response.status_code == 200: - return self.get_assertion_from_response(endpoint, - sms_parsed) - elif sms_response.status_code >= 400: - error = sms_parsed["errorCauses"][0]["errorSummary"] - self._password_prompter(("%s\r\n" - "Press RETURN to continue\r\n") - % error) + result = self.process_response(sms_response, endpoint) + if result is not None: + return result def display_mfa_choices(self, parsed): index = 1 @@ -377,19 +369,23 @@ def display_mfa_choices(self, parsed): index += 1 return index, prompt + def get_number(self, prompt): + response = self.get_response(prompt) + choice = 0 + try: + choice = int(response) + except ValueError: + pass + return choice + def get_mfa_choice(self, parsed): + count, prompt = self.display_mfa_choices(parsed) + prompt = ("Please choose from the following authentication" + " choices:\r\n") + prompt + prompt += ("Enter the number corresponding to your choice " + "or press RETURN to cancel authentication: ") while True: - count, prompt = self.display_mfa_choices(parsed) - prompt = ("Please choose from the following authentication" - " choices:\r\n") + prompt - prompt += ("Enter the number corresponding to your choice " - "or press RETURN to cancel authentication: ") - response = self._password_prompter(prompt) - choice = 0 - try: - choice = int(response) - except ValueError: - pass + choice = self.get_number(prompt) if choice > 0 and choice < count: return choice diff --git a/tests/functional/test_saml.py b/tests/functional/test_saml.py index baf26ee..14ab407 100644 --- a/tests/functional/test_saml.py +++ b/tests/functional/test_saml.py @@ -23,30 +23,92 @@ def argv(): ] -@mock.patch( - 'awsprocesscreds.saml.OktaAuthenticator.obtain_input', - return_value="", - autospec=True -) -def test_get_response_1(mock_obtain_input): - authenticator = OktaAuthenticator(None) +def test_get_response_1(): + def mock_prompter(prompt): + return "" + + authenticator = OktaAuthenticator(mock_prompter) with pytest.raises(SAMLError): authenticator.get_response("") -@mock.patch( - 'awsprocesscreds.saml.OktaAuthenticator.obtain_input', - return_value="mock_result", - autospec=True -) -def test_get_response_2(mock_obtain_input): - authenticator = OktaAuthenticator(None) +def test_get_response_2(): + def mock_prompter(prompt): + return "mock_result" + + authenticator = OktaAuthenticator(mock_prompter) response = authenticator.get_response("") assert response == "mock_result" +def test_get_response_3(): + def mock_prompter(prompt): + return "" + + authenticator = OktaAuthenticator(mock_prompter) + response = authenticator.get_response("", False) + assert response == "" + + +def test_process_response_1(mock_requests_session, assertion, prompter): + assertion_form = '
' + assertion_form = assertion_form % assertion.decode() + assertion_response = mock.Mock( + spec=requests.Response, status_code=200, text=assertion_form + ) + mock_requests_session.get.return_value = assertion_response + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} + token_response = mock.Mock( + spec=requests.Response, + status_code=200, + text=json.dumps(session_token) + ) + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=prompter) + + result = authenticator._authenticator.process_response( + token_response, "endpoint") + assert result == assertion.decode() + + +def test_process_response_2(mock_requests_session, assertion, prompter): + def mock_prompter(prompt): + assert prompt == "Mock error\r\nPress RETURN to continue\r\n" + return "" + + session_token = { + 'sessionToken': 'spam', + 'status': 'FAILED', + 'errorCauses': [ + { + 'errorSummary': "Mock error" + } + ] + } + token_response = mock.Mock( + spec=requests.Response, + status_code=400, + text=json.dumps(session_token) + ) + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=mock_prompter) + + result = authenticator._authenticator.process_response( + token_response, "endpoint") + assert result is None + + def test_process_mfa_totp( mock_requests_session, prompter, assertion, capsys): + def mock_prompter(prompt): + return "12345678" + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} token_response = mock.Mock( spec=requests.Response, @@ -66,14 +128,11 @@ def test_process_mfa_totp( client_creator=None, saml_config=None, provider_name="okta", - password_prompter=prompter) + password_prompter=mock_prompter) - with mock.patch( - "awsprocesscreds.saml.OktaAuthenticator.obtain_input", - return_value="12345678"): - result = authenticator._authenticator.process_mfa_totp( - "endpoint", "url", "statetoken") - assert result == assertion.decode() + result = authenticator._authenticator.process_mfa_totp( + "endpoint", "url", "statetoken") + assert result == assertion.decode() def test_process_mfa_push_1( @@ -131,6 +190,9 @@ def test_process_mfa_push_2( def test_process_mfa_security_question( mock_requests_session, prompter, assertion, capsys): + def mock_prompter(prompt): + return "security_answer" + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} token_response = mock.Mock( spec=requests.Response, @@ -150,14 +212,11 @@ def test_process_mfa_security_question( client_creator=None, saml_config=None, provider_name="okta", - password_prompter=prompter) + password_prompter=mock_prompter) - with mock.patch( - "awsprocesscreds.saml.OktaAuthenticator.get_response", - return_value="security_answer"): - result = authenticator._authenticator.process_mfa_security_question( - "endpoint", "url", "statetoken") - assert result == assertion.decode() + result = authenticator._authenticator.process_mfa_security_question( + "endpoint", "url", "statetoken") + assert result == assertion.decode() def test_verify_sms_factor( @@ -183,6 +242,9 @@ def test_verify_sms_factor( def test_process_mfa_sms( mock_requests_session, prompter, assertion, capsys): + def mock_prompter(prompt): + return "12345678" + session_token = {'sessionToken': 'spam', 'status': 'SUCCESS'} token_response = mock.Mock( spec=requests.Response, @@ -202,16 +264,14 @@ def test_process_mfa_sms( client_creator=None, saml_config=None, provider_name="okta", - password_prompter=prompter) + password_prompter=mock_prompter) + with mock.patch( - "awsprocesscreds.saml.OktaAuthenticator.get_response", - return_value="12345678"): - with mock.patch( - "awsprocesscreds.saml.OktaAuthenticator.verify_sms_factor", - return_value=token_response): - result = authenticator._authenticator.process_mfa_sms( - "endpoint", "url", "statetoken") - assert result == assertion.decode() + "awsprocesscreds.saml.OktaAuthenticator.verify_sms_factor", + return_value=token_response): + result = authenticator._authenticator.process_mfa_sms( + "endpoint", "url", "statetoken") + assert result == assertion.decode() def test_display_mfa_choices( @@ -260,6 +320,32 @@ def test_display_mfa_choices( ) +def test_get_number_1(prompter): + def mock_prompter(prompt): + return "1" + + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=mock_prompter) + response = authenticator._authenticator.get_number("") + assert response == 1 + + +def test_get_number_2(prompter): + def mock_prompter(prompt): + return "fred" + + authenticator = SAMLCredentialFetcher( + client_creator=None, + saml_config=None, + provider_name="okta", + password_prompter=mock_prompter) + response = authenticator._authenticator.get_number("") + assert response == 0 + + def test_get_mfa_choice( mock_requests_session, prompter, assertion, capsys): def mock_prompter(prompt): From d66a173925e9b2779ca21d0354756d086de93e8e Mon Sep 17 00:00:00 2001 From: Philip Colmer Date: Sat, 9 Feb 2019 07:35:41 +0000 Subject: [PATCH 13/13] Fix linter issues --- awsprocesscreds/saml.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/awsprocesscreds/saml.py b/awsprocesscreds/saml.py index 096adc3..72b54df 100644 --- a/awsprocesscreds/saml.py +++ b/awsprocesscreds/saml.py @@ -270,7 +270,7 @@ def process_response(self, response, endpoint): parsed = json.loads(response.text) if response.status_code == 200: return self.get_assertion_from_response(endpoint, parsed) - elif response.status_code >= 400: + if response.status_code >= 400: error = parsed["errorCauses"][0]["errorSummary"] self.get_response("%s\r\nPress RETURN to continue\r\n" % error, False) @@ -303,7 +303,7 @@ def process_mfa_push(self, endpoint, url, statetoken): totp_parsed = json.loads(totp_response.text) if totp_parsed["status"] == "SUCCESS": return self.get_assertion_from_response(endpoint, totp_parsed) - elif totp_parsed["factorResult"] != "WAITING": + if totp_parsed["factorResult"] != "WAITING": raise SAMLError(self._ERROR_AUTH_CANCELLED) def process_mfa_security_question(self, endpoint, url, statetoken): @@ -387,7 +387,7 @@ def get_mfa_choice(self, parsed): "or press RETURN to cancel authentication: ") while True: choice = self.get_number(prompt) - if choice > 0 and choice < count: + if 0 < choice < count: return choice def process_mfa_verification(self, endpoint, parsed): @@ -401,15 +401,15 @@ def process_mfa_verification(self, endpoint, parsed): statetoken = parsed["stateToken"] if factor["factorType"] == "token:software:totp": return self.process_mfa_totp(endpoint, url, statetoken) - elif factor["factorType"] == "push": + if factor["factorType"] == "push": return self.process_mfa_push(endpoint, url, statetoken) - elif factor["factorType"] == "question": + if factor["factorType"] == "question": return self.process_mfa_security_question(endpoint, url, statetoken) - elif factor["factorType"] == "sms": + if factor["factorType"] == "sms": return self.process_mfa_sms(endpoint, url, statetoken) - else: - raise SAMLError("Unsupported factor") + + raise SAMLError("Unsupported factor") def retrieve_saml_assertion(self, config): self._validate_config_values(config) @@ -440,15 +440,15 @@ def retrieve_saml_assertion(self, config): if "status" in parsed: if parsed["status"] == "SUCCESS": return self.get_assertion_from_response(endpoint, parsed) - elif parsed["status"] == "LOCKED_OUT": + if parsed["status"] == "LOCKED_OUT": raise SAMLError(self._ERROR_LOCKED_OUT % parsed["_links"]["href"]) - elif parsed["status"] == "PASSWORD_EXPIRED": + if parsed["status"] == "PASSWORD_EXPIRED": raise SAMLError(self._ERROR_PASSWORD_EXPIRED % parsed["_links"]["href"]) - elif parsed["status"] == "MFA_ENROLL": + if parsed["status"] == "MFA_ENROLL": raise SAMLError(self._ERROR_MFA_ENROLL) - elif parsed["status"] == "MFA_REQUIRED": + if parsed["status"] == "MFA_REQUIRED": return self.process_mfa_verification(endpoint, parsed) raise SAMLError("Code logic failure")