Skip to content

Commit 3978600

Browse files
committed
feat(pipeline): implement full ml pipeline
1 parent 45cf3ac commit 3978600

File tree

2 files changed

+64
-17
lines changed

2 files changed

+64
-17
lines changed

pipeline/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99

1010

1111
class Model:
12+
# uniquely identifies versions of the model.
13+
# NOTE: update this if the model changed in any way.
14+
MODEL_ID = "yolo_detect_segment_v1"
15+
1216
def __init__(
1317
self,
1418
project_id: str = "flowmotion-4e268",

pipeline/pipeline.py

Lines changed: 60 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,72 @@
1-
from api import APIClient
1+
#
2+
# Flowmotion
3+
# ML Pipeline
4+
# Entrypoint
5+
#
6+
7+
8+
from pathlib import Path
9+
from tempfile import mkdtemp
10+
from typing import cast
11+
from api import TrafficImageAPI
12+
from data import Congestion, Rating
13+
from db import DatabaseClient
214
from model import Model
315
from rating_validator import RatingValidator
416
from TrafficImage import TrafficImage
17+
from datetime import datetime, timezone
18+
from zoneinfo import ZoneInfo
19+
20+
21+
def datetime_sgt() -> datetime:
22+
"""Get the current timestamp in the SGT timezone."""
23+
return datetime.now(ZoneInfo("Asia/Singapore"))
24+
525

626
if __name__ == "__main__":
7-
apiclient = APIClient("https://api.data.gov.sg/v1/transport/traffic-images")
8-
active_cameraIDs = apiclient.camera_id_array
9-
current_traffic_camera_objects = [] # array of TrafficImage objects
27+
# fetch camera metadata & images from traffic image api
28+
api = TrafficImageAPI()
29+
cameras = api.get_cameras()
30+
image_dir = Path(mkdtemp())
31+
image_paths = api.get_images(cameras, image_dir)
1032

1133
# populating array of TrafficImage objects
12-
for camera_id in active_cameraIDs:
13-
image_url = apiclient.extract_image(camera_id)
14-
longitude, latitude = apiclient.extract_latlon(camera_id)
15-
traffic_camera_obj = TrafficImage(
16-
camera_id=camera_id, image=image_url, longitude=longitude, latitude=latitude
34+
traffic_images = [] # type: list[TrafficImage]
35+
for camera, image_path in zip(cameras, image_paths):
36+
traffic_images.append(
37+
TrafficImage(
38+
image=str(image_path),
39+
camera_id=camera.id,
40+
longitude=camera.location.longitude,
41+
latitude=camera.location.latitude,
42+
)
1743
)
18-
current_traffic_camera_objects.append(traffic_camera_obj)
1944

20-
# run model
45+
# perform inference on model for traffic congestion rating
2146
active_model = Model()
22-
active_model.predict(current_traffic_camera_objects)
23-
47+
active_model.predict(traffic_images)
48+
rated_on = datetime_sgt()
2449
# validate model output
25-
for object in current_traffic_camera_objects:
26-
validator = RatingValidator(object)
27-
validator.validate()
50+
for traffic_image in traffic_images:
51+
RatingValidator(traffic_image).validate()
52+
53+
# construct congestion with model's traffic congestion rating
54+
congestions = [] # type: list[Congestion]
55+
updated_on = datetime_sgt()
56+
for camera, traffic_image in zip(cameras, traffic_images):
57+
congestions.append(
58+
Congestion(
59+
camera=camera,
60+
rating=Rating(
61+
model_id=Model.MODEL_ID,
62+
rated_on=rated_on,
63+
value=cast(float, traffic_image.congestion_rating),
64+
),
65+
updated_on=updated_on,
66+
),
67+
)
2868

29-
# iterate through current_traffic_camera_objects and populate json for firebase storage
69+
# write congestions to the database
70+
db = DatabaseClient()
71+
for congestion in congestions:
72+
db.insert("congestions", congestion)

0 commit comments

Comments
 (0)