Skip to content

Commit 2ac97e3

Browse files
committed
Update add script to add a range and prioritise questions asked from 9am-12pm and 8pm-10pm as well as questions asked during the weekend
1 parent 5c6ad01 commit 2ac97e3

File tree

1 file changed

+74
-24
lines changed

1 file changed

+74
-24
lines changed

core_backend/add_new_data_to_db.py

+74-24
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
QueryResponseContentDB,
2424
ResponseFeedbackDB,
2525
)
26-
from app.urgency_detection.models import UrgencyQueryDB
26+
from app.urgency_detection.models import UrgencyQueryDB, UrgencyResponseDB
2727
from app.users.models import UserDB
2828
from app.utils import get_key_hash
2929
from litellm import completion
@@ -50,6 +50,7 @@
5050
(ContentFeedbackDB, "feedback_datetime_utc"),
5151
(QueryResponseContentDB, "created_datetime_utc"),
5252
(UrgencyQueryDB, "message_datetime_utc"),
53+
(UrgencyResponseDB, "response_datetime_utc"),
5354
]
5455

5556
parser = argparse.ArgumentParser(
@@ -64,6 +65,7 @@
6465
--api-key <API_KEY> \
6566
--nb-workers 8 \
6667
--start-date 01-08-23
68+
--end-date 04-09-24
6769
6870
""",
6971
)
@@ -82,6 +84,16 @@
8284
help="Start date for the records in the format dd-mm-yy",
8385
required=False,
8486
)
87+
parser.add_argument(
88+
"--end-date",
89+
help="End date for the records in the format dd-mm-yy",
90+
required=False,
91+
)
92+
parser.add_argument(
93+
"--subset",
94+
help="Subset of the data to use for testing",
95+
required=False,
96+
)
8597
args = parser.parse_args()
8698

8799

@@ -281,24 +293,60 @@ def process_urgency_detection(_id: int, text: str) -> tuple | None:
281293
return None
282294

283295

284-
def create_random_datetime_from_string(start_date: datetime) -> datetime:
296+
def create_random_datetime(start_date: datetime, end_date: datetime) -> datetime:
285297
"""
286-
Create a random datetime from a date in the format "%d-%m-%y
287-
to today
298+
Create a random datetime from a date within a range
288299
"""
289300

290-
time_difference = datetime.now() - start_date
301+
time_difference = end_date - start_date
291302
random_number_of_days = random.randint(0, time_difference.days)
292303

293-
random_number_of_seconds = random.randint(0, 86399) # Number of seconds in one day
294-
304+
random_number_of_seconds = random.randint(0, 86399)
295305
random_datetime = start_date + timedelta(
296306
days=random_number_of_days, seconds=random_number_of_seconds
297307
)
298308
return random_datetime
299309

300310

301-
def update_date_of_records(models: list, random_dates: list, api_key: str) -> None:
311+
def is_within_time_range(date: datetime) -> bool:
312+
"""
313+
Helper function to check if the date is within desired time range.
314+
Prioritizing 9am-12pm and 8pm-10pm
315+
"""
316+
if 9 <= date.hour < 12 or 20 <= date.hour < 22:
317+
return True
318+
return False
319+
320+
321+
def generate_distributed_dates(n: int, start: datetime, end: datetime) -> list:
322+
"""
323+
Generate dates with a specific distribution for the records
324+
"""
325+
dates: list[datetime] = []
326+
while len(dates) < n:
327+
date = create_random_datetime(start, end)
328+
329+
# More dates on weekends
330+
if date.weekday() >= 5:
331+
332+
if (
333+
is_within_time_range(date) or random.random() < 0.4
334+
): # Within time range or 30% chance
335+
dates.append(date)
336+
else:
337+
if random.random() < 0.6:
338+
if is_within_time_range(date) or random.random() < 0.55:
339+
dates.append(date)
340+
341+
return dates
342+
343+
344+
def update_date_of_records(
345+
models: list,
346+
api_key: str,
347+
start_date: datetime,
348+
end_date: datetime,
349+
) -> None:
302350
"""
303351
Update the date of the records in the database
304352
"""
@@ -308,11 +356,7 @@ def update_date_of_records(models: list, random_dates: list, api_key: str) -> No
308356
select(UserDB).where(UserDB.hashed_api_key == hashed_token)
309357
).scalar_one()
310358
queries = [c for c in session.query(QueryDB).all() if c.user_id == user.user_id]
311-
if len(queries) > len(random_dates):
312-
random_dates = random_dates + [
313-
create_random_datetime_from_string(start_date)
314-
for _ in range(len(queries) - len(random_dates))
315-
]
359+
random_dates = generate_distributed_dates(len(queries), start_date, end_date)
316360
# Create a dictionary to map the query_id to the random date
317361
date_map_dic = {queries[i].query_id: random_dates[i] for i in range(len(queries))}
318362
for model in models:
@@ -323,8 +367,8 @@ def update_date_of_records(models: list, random_dates: list, api_key: str) -> No
323367

324368
for i, row in enumerate(rows):
325369
# Set the date attribute to the random date
326-
if hasattr(row, "query_id"):
327-
date = date_map_dic[row.query_id]
370+
if hasattr(row, "query_id") and model[0] != UrgencyQueryDB:
371+
date = date_map_dic.get(row.query_id, None)
328372
else:
329373
date = random_dates[i]
330374

@@ -351,17 +395,26 @@ def update_date_of_contents(date: datetime) -> None:
351395
NB_WORKERS = int(args.nb_workers) if args.nb_workers else 8
352396
API_KEY = args.api_key if args.api_key else ADMIN_API_KEY
353397

354-
date_string = args.start_date if args.start_date else "01-08-23"
398+
start_date_string = args.start_date if args.start_date else "01-08-23"
399+
end_date_string = args.end_date if args.end_date else None
355400
date_format = "%d-%m-%y"
356-
start_date = datetime.strptime(date_string, date_format)
401+
start_date = datetime.strptime(start_date_string, date_format)
402+
end_date = (
403+
datetime.strptime(end_date_string, date_format)
404+
if end_date_string
405+
else datetime.now()
406+
)
407+
assert end_date, "Invalid end date. Please provide a valid date. Format is dd-mm-yy"
357408
assert (
358-
start_date and start_date < datetime.now()
359-
), "Invalid start date. Please provide a valid start date."
409+
start_date and start_date < end_date
410+
), "Invalid start date. Please provide a valid start date. Format is dd-mm-yy"
360411

412+
subset = int(args.subset) if args.subset else None
361413
path = args.csv
362-
df = pd.read_csv(path)
414+
df = pd.read_csv(path, nrows=subset)
363415
saved_queries = defaultdict(list)
364416
print("Processing search queries...")
417+
365418
# Using multithreading to speed up the process
366419
with ThreadPoolExecutor(max_workers=NB_WORKERS) as executor:
367420
future_to_text = {
@@ -444,11 +497,8 @@ def update_date_of_contents(date: datetime) -> None:
444497
result = future.result()
445498
print("Urgency Detection successfully processed")
446499

447-
random_dates = [
448-
create_random_datetime_from_string(start_date) for _ in range(len(df))
449-
]
450500
print("Updating the date of the records...")
451-
update_date_of_records(MODELS, random_dates, API_KEY)
501+
update_date_of_records(MODELS, API_KEY, start_date, end_date)
452502

453503
print("Updating the date of the content records...")
454504
update_date_of_contents(start_date)

0 commit comments

Comments
 (0)