Skip to content

Commit d5a3f3e

Browse files
authored
Merge pull request #1 from groboclown/master
Copy awslabs#11 - add multiple form field names for use with different ADFS versions
2 parents a216e70 + 820cd42 commit d5a3f3e

File tree

2 files changed

+63
-16
lines changed

2 files changed

+63
-16
lines changed

awsprocesscreds/saml.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ def retrieve_saml_assertion(self, config):
5757

5858

5959
class GenericFormsBasedAuthenticator(SAMLAuthenticator):
60-
USERNAME_FIELD = 'username'
61-
PASSWORD_FIELD = 'password'
60+
USERNAME_FIELDS = ('username',)
61+
PASSWORD_FIELDS = ('password',)
6262

6363
_ERROR_BAD_RESPONSE = (
6464
'Received a non-200 response (%s) when making a request to: %s'
@@ -175,13 +175,19 @@ def _parse_form_from_html(self, html):
175175

176176
def _fill_in_form_values(self, config, form_data):
177177
username = config['saml_username']
178-
if self.USERNAME_FIELD not in form_data:
178+
username_field = set(self.USERNAME_FIELDS).intersection(
179+
form_data.keys()
180+
)
181+
if not username_field:
179182
raise SAMLError(
180-
self._ERROR_MISSING_FORM_FIELD % self.USERNAME_FIELD)
181-
else:
182-
form_data[self.USERNAME_FIELD] = username
183-
if self.PASSWORD_FIELD in form_data:
184-
form_data[self.PASSWORD_FIELD] = self._password_prompter(
183+
self._ERROR_MISSING_FORM_FIELD % self.USERNAME_FIELDS)
184+
form_data[username_field.pop()] = username
185+
186+
password_field = set(self.PASSWORD_FIELDS).intersection(
187+
form_data.keys()
188+
)
189+
if password_field:
190+
form_data[password_field.pop()] = self._password_prompter(
185191
"Password: ")
186192

187193
def _send_form_post(self, login_url, form_data):
@@ -250,17 +256,27 @@ def retrieve_saml_assertion(self, config):
250256
return r
251257

252258
def is_suitable(self, config):
253-
return (config.get('saml_authentication_type') == 'form' and
254-
config.get('saml_provider') == 'okta')
259+
return (
260+
config.get('saml_authentication_type') == 'form'
261+
and config.get('saml_provider') == 'okta'
262+
)
255263

256264

257265
class ADFSFormsBasedAuthenticator(GenericFormsBasedAuthenticator):
258-
USERNAME_FIELD = 'ctl00$ContentPlaceHolder1$UsernameTextBox'
259-
PASSWORD_FIELD = 'ctl00$ContentPlaceHolder1$PasswordTextBox'
266+
USERNAME_FIELDS = (
267+
'ctl00$ContentPlaceHolder1$UsernameTextBox',
268+
'UserName',
269+
)
270+
PASSWORD_FIELDS = (
271+
'ctl00$ContentPlaceHolder1$PasswordTextBox',
272+
'Password',
273+
)
260274

261275
def is_suitable(self, config):
262-
return (config.get('saml_authentication_type') == 'form' and
263-
config.get('saml_provider') == 'adfs')
276+
return (
277+
config.get('saml_authentication_type') == 'form'
278+
and config.get('saml_provider') == 'adfs'
279+
)
264280

265281

266282
class FormParser(six.moves.html_parser.HTMLParser):

tests/unit/test_saml.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -423,8 +423,8 @@ def test_non_adfs_not_suitable(self, adfs_auth):
423423
}
424424
assert not adfs_auth.is_suitable(config)
425425

426-
def test_uses_adfs_fields(self, adfs_auth, mock_requests_session,
427-
adfs_config):
426+
def test_uses_adfs_fields_newer(self, adfs_auth, mock_requests_session,
427+
adfs_config):
428428
adfs_login_form = (
429429
'<html>'
430430
'<form action="login">'
@@ -454,6 +454,37 @@ def test_uses_adfs_fields(self, adfs_auth, mock_requests_session,
454454
}
455455
)
456456

457+
def test_uses_adfs_fields_older(self, adfs_auth, mock_requests_session,
458+
adfs_config):
459+
adfs_login_form = (
460+
'<html>'
461+
'<form action="login">'
462+
'<input name="UserName"/>'
463+
'<input name="Password"/>'
464+
'</form>'
465+
'</html>'
466+
)
467+
mock_requests_session.get.return_value = mock.Mock(
468+
spec=requests.Response, status_code=200, text=adfs_login_form
469+
)
470+
mock_requests_session.post.return_value = mock.Mock(
471+
spec=requests.Response, status_code=200, text=(
472+
'<form><input name="SAMLResponse" '
473+
'value="fakeassertion"/></form>'
474+
)
475+
)
476+
477+
saml_assertion = adfs_auth.retrieve_saml_assertion(adfs_config)
478+
assert saml_assertion == 'fakeassertion'
479+
480+
mock_requests_session.post.assert_called_with(
481+
"https://example.com/login", verify=True,
482+
data={
483+
'UserName': 'monty',
484+
'Password': 'mypassword'
485+
}
486+
)
487+
457488

458489
class TestFormParser(object):
459490
def test_parse_form(self, basic_form):

0 commit comments

Comments
 (0)