|
1 |
| -from api import APIClient |
| 1 | +# |
| 2 | +# Flowmotion |
| 3 | +# ML Pipeline |
| 4 | +# Entrypoint |
| 5 | +# |
| 6 | + |
| 7 | + |
| 8 | +from pathlib import Path |
| 9 | +from tempfile import mkdtemp |
| 10 | +from typing import cast |
| 11 | +from api import TrafficImageAPI |
| 12 | +from data import Congestion, Rating |
| 13 | +from db import DatabaseClient |
2 | 14 | from model import Model
|
3 | 15 | from rating_validator import RatingValidator
|
4 | 16 | from TrafficImage import TrafficImage
|
| 17 | +from datetime import datetime, timezone |
| 18 | +from zoneinfo import ZoneInfo |
| 19 | + |
| 20 | + |
| 21 | +def datetime_sgt() -> datetime: |
| 22 | + """Get the current timestamp in the SGT timezone.""" |
| 23 | + return datetime.now(ZoneInfo("Asia/Singapore")) |
| 24 | + |
5 | 25 |
|
6 | 26 | if __name__ == "__main__":
|
7 |
| - apiclient = APIClient("https://api.data.gov.sg/v1/transport/traffic-images") |
8 |
| - active_cameraIDs = apiclient.camera_id_array |
9 |
| - current_traffic_camera_objects = [] # array of TrafficImage objects |
| 27 | + # fetch camera metadata & images from traffic image api |
| 28 | + api = TrafficImageAPI() |
| 29 | + cameras = api.get_cameras() |
| 30 | + image_dir = Path(mkdtemp()) |
| 31 | + image_paths = api.get_images(cameras, image_dir) |
10 | 32 |
|
11 | 33 | # populating array of TrafficImage objects
|
12 |
| - for camera_id in active_cameraIDs: |
13 |
| - image_url = apiclient.extract_image(camera_id) |
14 |
| - longitude, latitude = apiclient.extract_latlon(camera_id) |
15 |
| - traffic_camera_obj = TrafficImage( |
16 |
| - camera_id=camera_id, image=image_url, longitude=longitude, latitude=latitude |
| 34 | + traffic_images = [] # type: list[TrafficImage] |
| 35 | + for camera, image_path in zip(cameras, image_paths): |
| 36 | + traffic_images.append( |
| 37 | + TrafficImage( |
| 38 | + image=str(image_path), |
| 39 | + camera_id=camera.id, |
| 40 | + longitude=camera.location.longitude, |
| 41 | + latitude=camera.location.latitude, |
| 42 | + ) |
17 | 43 | )
|
18 |
| - current_traffic_camera_objects.append(traffic_camera_obj) |
19 | 44 |
|
20 |
| - # run model |
| 45 | + # perform inference on model for traffic congestion rating |
21 | 46 | active_model = Model()
|
22 |
| - active_model.predict(current_traffic_camera_objects) |
23 |
| - |
| 47 | + active_model.predict(traffic_images) |
| 48 | + rated_on = datetime_sgt() |
24 | 49 | # validate model output
|
25 |
| - for object in current_traffic_camera_objects: |
26 |
| - validator = RatingValidator(object) |
27 |
| - validator.validate() |
| 50 | + for traffic_image in traffic_images: |
| 51 | + RatingValidator(traffic_image).validate() |
| 52 | + |
| 53 | + # construct congestion with model's traffic congestion rating |
| 54 | + congestions = [] # type: list[Congestion] |
| 55 | + updated_on = datetime_sgt() |
| 56 | + for camera, traffic_image in zip(cameras, traffic_images): |
| 57 | + congestions.append( |
| 58 | + Congestion( |
| 59 | + camera=camera, |
| 60 | + rating=Rating( |
| 61 | + model_id=Model.MODEL_ID, |
| 62 | + rated_on=rated_on, |
| 63 | + value=cast(float, traffic_image.congestion_rating), |
| 64 | + ), |
| 65 | + updated_on=updated_on, |
| 66 | + ), |
| 67 | + ) |
28 | 68 |
|
29 |
| - # iterate through current_traffic_camera_objects and populate json for firebase storage |
| 69 | + # write congestions to the database |
| 70 | + db = DatabaseClient() |
| 71 | + for congestion in congestions: |
| 72 | + db.insert("congestions", congestion) |
0 commit comments