53 changes: 34 additions & 19 deletions python/idsse_common/idsse/common/protocol_utils.py
@@ -15,6 +15,7 @@
 
 from abc import abstractmethod, ABC
 from collections.abc import Sequence
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from datetime import datetime, timedelta, UTC
 
 from .path_builder import PathBuilder
@@ -91,13 +92,15 @@ def check_for(self, issue: datetime, valid: datetime, **kwargs) -> tuple[datetim
 return valid, os.path.join(dir_path, fname)
 return None
 
+# pylint: disable=too-many-arguments
 def get_issues(
 self,
 num_issues: int = 1,
 issue_start: datetime | None = None,
 issue_end: datetime | None = None,
 time_delta: timedelta = timedelta(hours=1),
-**kwargs
+max_workers: int = 24,
+**kwargs,
 ) -> Sequence[datetime]:
 """Determine the available issue date/times
 
@@ -106,6 +109,8 @@
 issue_start (datetime, optional): The oldest date/time to look for. Defaults to None.
 issue_end (datetime): The newest date/time to look for. Defaults to now (UTC).
 time_delta (timedelta): The time step size. Defaults to 1 hour.
+max_workers (int): The number of Python threads to use to make AWS ls() calls.
+Defaults to 24, which is reasonable. More threads will not necessarily run faster.
 kwargs: Additional arguments, e.g. region
 
 Returns:
@@ -125,26 +130,37 @@
 if time_delta > zero_time_delta:
 time_delta = timedelta(seconds=-1.0 * time_delta.total_seconds())
 datetimes = datetime_gen(issue_end, time_delta)
-for issue_dt in datetimes:
-if issue_start and issue_dt < issue_start:
-break
+
+# trim list of datetimes to requested length, then build list of unique datetimes
+# that are confirmed to exist in AWS. ls() calls of each issue_dt folder happen in parallel
+issue_filepaths = [
+(dt, self.path_builder.build_dir(issue=dt, **kwargs))
+for dt in list(datetimes)[:num_issues]
+]
+with ThreadPoolExecutor(max_workers, "AwsLsThread") as pool:
+futures = [
+pool.submit(self._get_issues, dir_path, num_issues)
+for dir_path in [
+dir_path
+for (dt, dir_path) in issue_filepaths
+if not (issue_start and dt < issue_start)
+]
+]
+for future in as_completed(futures):
 try:
-dir_path = self.path_builder.build_dir(issue=issue_dt, **kwargs)
-issues_set.update(self._get_issues(dir_path, num_issues))
-if num_issues and len(issues_set) >= num_issues:
-break
+issues_in_aws = future.result()
+issues_set.update(issues_in_aws)
 except PermissionError:
-pass
-if None in issues_set:
-issues_set.remove(None)
-return sorted(issues_set)[:num_issues]
+pass # last valid_dt wasn't quite available on AWS yet; skip that issue_dt
+
+return list(issues_set)
 
 def get_valids(
 self,
 issue: datetime,
 valid_start: datetime | None = None,
 valid_end: datetime | None = None,
-**kwargs
+**kwargs,
 ) -> Sequence[tuple[datetime, str]]:
 """Get all objects consistent with the passed issue date/time and filter by valid range
 
@@ -209,14 +225,13 @@ def _get_issues(self, dir_path: str, num_issues: int = 1) -> set[datetime]:
 issues_set: set[datetime] = set()
 # sort files alphabetically in reverse; this should give us the longest lead time first
 # which is more indicative that the issueDt is fully available on this server
-filepaths = sorted(
+final_valid_filepaths = sorted(
 (f for f in self.ls(dir_path) if f.endswith(self.path_builder.file_ext)), reverse=True
 )
-for file_path in filepaths:
+for valid_file_path in final_valid_filepaths:
 try:
-issues_set.add(self.path_builder.get_issue(file_path))
-if num_issues and len(issues_set) >= num_issues:
-break
+if issue_dt := self.path_builder.get_issue(valid_file_path):
+issues_set.add(issue_dt)
 except ValueError: # Ignore invalid filepaths...
 pass
-return issues_set
+return sorted(list(issues_set), reverse=True)[:num_issues]
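Note (not part of the diff): the core of the protocol_utils.py change is fanning out the per-folder ls() calls across a thread pool instead of checking each issue folder sequentially. Below is a minimal, self-contained sketch of that pattern under assumed names: check_folder() stands in for the ls()-based _get_issues() helper, and the S3-style paths are made up for illustration.

from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta, UTC

def check_folder(path: str) -> set[datetime]:
    # stand-in for an AWS ls() call that parses issue datetimes out of object keys
    base = datetime(2024, 7, 1, tzinfo=UTC)
    return {base + timedelta(hours=int(path.rstrip("/").rsplit("/", 1)[-1]))}

candidate_dirs = [f"s3://bucket/data/{hour:02d}/" for hour in range(6)]
issues: set[datetime] = set()
with ThreadPoolExecutor(max_workers=24, thread_name_prefix="AwsLsThread") as pool:
    # submit one listing call per candidate folder, then collect results as they finish
    futures = [pool.submit(check_folder, path) for path in candidate_dirs]
    for future in as_completed(futures):
        try:
            issues.update(future.result())
        except PermissionError:
            pass  # folder missing or not readable; skip it
print(sorted(issues, reverse=True))

Because each listing is network-bound, the threads spend most of their time waiting on I/O, which is why a modest pool (the new max_workers default of 24) helps but adding more threads will not necessarily run faster.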
7 changes: 2 additions & 5 deletions python/idsse_common/idsse/common/rabbitmq_utils.py
@@ -31,6 +31,7 @@
 ChannelClosed,
 ChannelWrongStateError,
 ConnectionClosed,
+StreamLostError,
 )
 from pika.frame import Method
 from pika.spec import Basic
@@ -246,6 +247,7 @@ def run(self):
 ConnectionResetError,
 ChannelClosed,
 ChannelWrongStateError,
+StreamLostError,
 ) as exc:
 _logger.warning(
 "RabbitMQ connection closed unexpectedly, reconnecting now. Exc: [%s] %s",
@@ -291,11 +293,6 @@ def blocking_publish(
 publisher is configured to confirm delivery will return False if
 failed to confirm.
 """
-if not self.channel.is_open:
-# somehow RabbitMQ channel closed itself. Forceably create new connection/channel
-logger.warning("Attempt to publish to closed connection. Reconnecting to RabbitMQ now")
-self.channel = self._connect()
-
 return blocking_publish(
 self.channel, self._exch, RabbitMqMessage(message, properties, route_key), self._queue
 )
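Note (not part of the diff): the rabbitmq_utils.py change treats pika's StreamLostError, raised when the underlying TCP stream drops, the same as the other connection-loss exceptions, so the consumer's run() loop reconnects instead of letting the error escape the thread. A sketch of that retry pattern, assuming a hypothetical connect_and_consume callable that blocks until the connection dies:

import logging
import time

from pika.exceptions import (
    ChannelClosed,
    ChannelWrongStateError,
    ConnectionClosed,
    StreamLostError,
)

logger = logging.getLogger(__name__)

def consume_forever(connect_and_consume) -> None:
    # keep consuming; any connection-loss exception triggers a reconnect rather than a crash
    while True:
        try:
            connect_and_consume()
        except (
            ConnectionClosed,
            ConnectionResetError,
            ChannelClosed,
            ChannelWrongStateError,
            StreamLostError,
        ) as exc:
            logger.warning("RabbitMQ connection lost: [%s] %s; reconnecting", type(exc).__name__, exc)
            time.sleep(1)  # brief pause before re-establishing the connection

With reconnection handled centrally in the consumer loop, blocking_publish no longer needs its own is_open check, which is why that block is removed above.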
1 change: 0 additions & 1 deletion python/idsse_common/idsse/common/utils.py
@@ -314,7 +314,6 @@ def datetime_gen(
 max_num = min(max_num, dt_cnt) if max_num else dt_cnt
 
 for i in range(0, max_num):
-logger.debug("dt generator %d/%d", i, max_num)
 yield dt_start + time_delta * i
 
 
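For context (not part of the diff): get_issues() above walks backward from issue_end by passing datetime_gen a negative time step, and the utils.py change simply drops a per-iteration debug log from that generator. A simplified, assumption-laden stand-in for the generator, showing the backward walk:

from datetime import datetime, timedelta, UTC
from typing import Iterator

def gen_datetimes(dt_start: datetime, time_delta: timedelta, max_num: int = 10) -> Iterator[datetime]:
    # simplified: the real datetime_gen also accepts an end datetime and derives max_num from it
    for i in range(max_num):
        yield dt_start + time_delta * i

latest = datetime(2024, 7, 1, 12, tzinfo=UTC)
print(list(gen_datetimes(latest, timedelta(hours=-1), max_num=4)))
# yields 12:00, 11:00, 10:00 and 09:00 UTC on 2024-07-01, newest first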