diff --git a/td/client.py b/td/client.py index 792df70..843eccf 100644 --- a/td/client.py +++ b/td/client.py @@ -7,6 +7,8 @@ import requests import urllib.parse +from authlib.integrations.httpx_client import OAuth2Client + from typing import Any from typing import Dict from typing import List @@ -52,7 +54,7 @@ class TDClient(): """ def __init__(self, client_id: str, redirect_uri: str, account_number: str = None, credentials_path: str = None, - auth_flow: str = 'default', _do_init: bool = True, _multiprocessing_safe = False) -> None: + auth_flow: str = 'default', _do_init: bool = True, _multiprocessing_safe = False, webdriver_path: str = "") -> None: """Creates a new instance of the TDClient Object. Initializes the session with default values and any user-provided overrides.The @@ -129,7 +131,10 @@ def __init__(self, client_id: str, redirect_uri: str, account_number: str = None self.client_id = client_id self.redirect_uri = redirect_uri self.account_number = account_number + self.webdriver_path = webdriver_path + self._token_endpoint = "https://api.tdameritrade.com/v1/oauth2/token" + self.credentials_path = pathlib.Path(credentials_path) self._td_utilities = TDUtilities() @@ -263,7 +268,10 @@ def login(self) -> bool: self.authstate = True return True else: - self.oauth() + if self.auth_flow == 'webdriver': + self.auth_using_webdriver() + else: + self.oauth() self.authstate = True return True @@ -298,7 +306,7 @@ def grab_access_token(self) -> dict: # Make the request. response = requests.post( - url="https://api.tdameritrade.com/v1/oauth2/token", + url=self._token_endpoint, headers={'Content-Type': 'application/x-www-form-urlencoded'}, data=data ) @@ -335,7 +343,7 @@ def grab_refresh_token(self) -> bool: # Make the request. response = requests.post( - url="https://api.tdameritrade.com/v1/oauth2/token", + url=self._token_endpoint, headers={'Content-Type': 'application/x-www-form-urlencoded'}, data=data ) @@ -389,6 +397,69 @@ def oauth(self) -> None: return_refresh_token=True ) + def auth_using_webdriver(self) -> None: + """Runs the oAuth process using webdriver for the TD Ameritrade API.""" + + print(f'Failed to find credentials json file \'{self.credentials_path}\'') + + from selenium import webdriver + with webdriver.Chrome(executable_path=self.webdriver_path) as driver: + self.auth_from_login_flow(driver) + + def _normalize_api_key(self, api_key): + api_key_suffix = '@AMER.OAUTHAP' + + if not api_key.endswith(api_key_suffix): + print(f'Appending {api_key_suffix} to API key') + api_key = api_key + api_key_suffix + return api_key + + class RedirectTimeoutError(Exception): + pass + + def auth_from_login_flow(self, driver): + print((f'Creating new token with redirect URL \'{self.redirect_uri}\' ' + + f'and credentials path \'{self.credentials_path}\'')) + + self.client_id = self._normalize_api_key(self.client_id) + + oauth = OAuth2Client(self.client_id, redirect_uri=self.redirect_uri) + authorization_url, state = oauth.create_authorization_url( + 'https://auth.tdameritrade.com/auth') + + driver.get(authorization_url) + + # Tolerate redirects to HTTPS on the callback URL + if self.redirect_uri.startswith('http://'): + redirect_urls = (self.redirect_uri, 'https' + self.redirect_uri[4:]) + else: + redirect_urls = (self.redirect_uri,) + + # Wait until the current URL starts with the callback URL + current_url = '' + num_waits = 0 + redirect_wait_time_seconds = 0.1 + max_waits = 3000 + while not any(current_url.startswith(r_url) for r_url in redirect_urls): + current_url = driver.current_url + + if num_waits > max_waits: + raise RedirectTimeoutError('timed out waiting for redirect') + time.sleep(redirect_wait_time_seconds) + num_waits += 1 + + token = oauth.fetch_token( + self._token_endpoint, + authorization_response=self.redirect_uri, + access_type='offline', + client_id=self.client_id, + include_client_id=True) + + print(token) + + # TODO: implement refresh token mode + self._token_save(token_dict=token) + def exchange_code_for_token(self, code: str, return_refresh_token: bool) -> dict: """Access token handler for AuthCode Workflow. @@ -422,7 +493,7 @@ def exchange_code_for_token(self, code: str, return_refresh_token: bool) -> dict # Make the request. response = requests.post( - url="https://api.tdameritrade.com/v1/oauth2/token", + url=self._token_endpoint, headers={'Content-Type': 'application/x-www-form-urlencoded'}, data=data )