Skip to content

Commit 3431c58

Browse files
Add GenAI chat feature
1 parent 205d3e4 commit 3431c58

File tree

2 files changed

+90
-5
lines changed

2 files changed

+90
-5
lines changed

pyalgotrading/algobulls/api.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def __init__(self, connection):
3030
self.__key_papertrading = {} # strategy-cstc_id mapping
3131
self.__key_realtrading = {} # strategy-cstc_id mapping
3232
self.pattern = re.compile(r'(?<!^)(?=[A-Z])')
33+
self.genai_api_key = None
34+
self.genai_session_id = None
3335

3436
def __convert(self, _dict):
3537
# Helps convert _dict keys from camelcase to snakecase
@@ -473,3 +475,47 @@ def get_reports(self, strategy_code: str, trading_type: TradingType, report_type
473475
response = self._send_request(endpoint=endpoint, params=params)
474476

475477
return response
478+
479+
def get_genai(self, user_prompt: str, session_id: int, chat_gpt_model: str = ''):
480+
"""
481+
Fetch GenAI response.
482+
483+
Args:
484+
user_prompt: User question
485+
session_id: Session id of the GenAI session
486+
chat_gpt_model: Chat gpt model name
487+
Returns:
488+
GenAI response
489+
490+
Info: ENDPOINT
491+
`GET` v1/build/python/genai Get GenAI response
492+
"""
493+
endpoint = 'v1/build/python/genai'
494+
params = {"userPrompt": user_prompt, 'sessionId': self.genai_session_id, 'openaiApiKey': self.genai_api_key, 'chat_gpt_model': chat_gpt_model}
495+
response = self._send_request(endpoint=endpoint, params=params)
496+
497+
return response
498+
499+
def get_genai_response(self):
500+
"""
501+
Fetch GenAI response.
502+
503+
Args:
504+
Returns:
505+
GenAI response for current session. Last active session is used when session_id is None.
506+
507+
Info: ENDPOINT
508+
`GET` v1/build/python/genai/response Pooling API to get response in case of timeout
509+
"""
510+
endpoint = 'v1/build/python/genai/response'
511+
params = {'sessionId': self.genai_session_id}
512+
response = self._send_request(endpoint=endpoint, params=params)
513+
514+
return response
515+
516+
def get_genai_sessions(self, page_no):
517+
endpoint = 'v1/build/python/genai/sessions'
518+
params = {'sessionId': self.genai_session_id}
519+
response = self._send_request(endpoint=endpoint, params=params)
520+
521+
return response

pyalgotrading/algobulls/connection.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,54 @@ def set_generative_ai_keys(self, api_key, secret_key, token_key):
9090
# todo: learn about different generative AIs and how many and what keys are required ?
9191
# also confirm with backend the format of API keys to be received
9292
"""
93-
93+
assert isinstance(api_key, str), f'Argument "api_key" should be a string'
94+
self.api.openai_key = api_key
9495
return 'SUCCESS' or 'FAILURE'
9596

96-
def generate_strategy(self):
97-
input_prompt = str(input())
97+
def start_chat(self, chat_gpt_model):
98+
while True:
99+
user_prompt = str(input())
100+
if user_prompt.lower() == 'exit':
101+
print("Thanks for the chat")
102+
return
103+
104+
response = self.api.get_genai(user_prompt, chat_gpt_model)
105+
while response['status_code'] == 504:
106+
response = self.api.get_genai_response()
107+
108+
print(f"GenAI: {response['message']}")
109+
110+
def continue_from_previous_session(self, page_no):
111+
"""
112+
display previous sessions
113+
Returns:
98114
99-
# call the api
115+
"""
116+
customer_genai_sessions = self.api.get_genai_sessions(page_no)
117+
for i, session in enumerate(customer_genai_sessions):
118+
print(f"Session {i}: ID: {session['id']}, Started: {session['last_user_prompt']}")
100119

101-
return # strategy in strings
120+
if len(customer_genai_sessions) < 20:
121+
print("End")
122+
else:
123+
print(f"Type 'next' to view the next 20 sessions.")
124+
125+
user_input = input("Enter session number or 'next': ")
126+
if user_input.lower() == "next" and len(customer_genai_sessions) > 20:
127+
self.continue_from_previous_session(page_no=page_no + 1)
128+
elif user_input.isdigit() and 1 <= int(user_input) <= len(customer_genai_sessions):
129+
selected_session_index = page_no + int(user_input) - 1
130+
selected_session_id = customer_genai_sessions[selected_session_index]["id"]
131+
self.api.genai_api_key = selected_session_id
132+
133+
def initiate_chat(self, start_fresh=None, chat_gpt_model=None):
134+
if start_fresh:
135+
# reset session
136+
self.api.genai_api_key = None
137+
elif start_fresh is not None:
138+
self.continue_from_previous_session(page_no=1)
139+
140+
self.start_chat(chat_gpt_model)
102141

103142
def save_latest_generated_strategy(self):
104143
pass

0 commit comments

Comments
 (0)