Skip to content

Commit 9b902b6

Browse files
authored
feat(pipeline): implement ML pipeline end to end (#15)
* revert: api client to previous implementation * refactor(pipeline)!: TrafficImageAPI.get_images() to download image files to disk * feat(pipeline): implement full ml pipeline * feat(pipeline): add TrafficImage.from_camera() factory method * feat(pipeline): TrafficImage to track processed_on on timestamp * refactor(pipeline)!: move TrafficImage to data.py to break circular import * chore(pipeline): fix spelling * fix(pipeline): warning about async fixture test scope * feat(pipeline): add .from_traffic_image() methods to Congestion & Rating models * refactor(pipeline): TrafficImageAPI.get_traffic_images() to return TrafficImage instances * fix(pipeline): google cloud project id not set in db client
1 parent e3e938e commit 9b902b6

15 files changed

+377
-243
lines changed

pipeline/TrafficImage.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

pipeline/api.py

Lines changed: 86 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,86 @@
1-
import requests
2-
3-
4-
class APIClient:
5-
def __init__(self, url="https://api.data.gov.sg/v1/transport/traffic-images"):
6-
self.url = url
7-
self.timestamp = None
8-
self.api_status = "Unverified"
9-
self.camera_id_array = []
10-
11-
# Get API response
12-
response = requests.get(self.url)
13-
response_json = response.json()
14-
self.metadata = response_json
15-
16-
# Get and set API status
17-
self.api_status = self.metadata["api_info"]["status"]
18-
19-
# Get and set timestamp
20-
self.timestamp = self.metadata["items"][0]["timestamp"]
21-
22-
print(f"The API status is: {self.api_status}")
23-
print(f"The API was called at: {self.timestamp}")
24-
25-
for item in self.metadata["items"]:
26-
for camera in item["cameras"]:
27-
self.camera_id_array.append(camera["camera_id"])
28-
29-
def extract_image(self, camera_id):
30-
# Loop through the items and cameras to find the correct camera_id
31-
for item in self.metadata["items"]:
32-
for camera in item["cameras"]:
33-
if camera["camera_id"] == str(camera_id):
34-
return camera[
35-
"image"
36-
] # Return the image URL if the camera ID matches
37-
# If camera ID is not found
38-
raise RuntimeError(f"Camera ID {camera_id} not found.")
39-
40-
def extract_latlon(self, camera_id):
41-
for item in self.metadata["items"]:
42-
for camera in item["cameras"]:
43-
if camera["camera_id"] == str(camera_id):
44-
longitude = camera["location"]["longitude"]
45-
latitude = camera["location"]["latitude"]
46-
return (
47-
longitude,
48-
latitude,
49-
) # Return both longitude and latitude as a tuple
50-
# If camera ID is not found
51-
raise RuntimeError(f"Camera ID {camera_id} not found.")
1+
#
2+
# Flowmotion
3+
# Pipeline
4+
# Traffic Images API Client
5+
#
6+
7+
import asyncio
8+
from pathlib import Path
9+
10+
import httpx
11+
12+
from data import Camera, Location, TrafficImage
13+
14+
15+
class TrafficImageAPI:
    """Data.gov.sg Traffic Images API Client.

    Attributes:
        api_url: Traffic images endpoint queried by get_cameras().
    """

    API_URL = "https://api.data.gov.sg/v1/transport/traffic-images"

    def __init__(self, api_url: str = API_URL):
        """Create the API client.

        Args:
            api_url:
                Endpoint to fetch camera metadata from. Defaults to API_URL.
        """
        self.api_url = api_url
        self._sync = httpx.Client()
        # NOTE(review): this async client is created once here but is used
        # inside a fresh event loop on every get_traffic_images() call
        # (asyncio.run) -- confirm reuse across event loops is safe if the
        # method is called more than once per instance.
        self._async = httpx.AsyncClient()

    def get_cameras(self) -> list[Camera]:
        """Get Traffic Camera metadata from traffic images API endpoint.

        Returns:
            Parsed traffic camera metadata.
        Raises:
            httpx.HTTPStatusError: If the endpoint responds with an error status.
        """
        # fetch the configured endpoint, so that an api_url passed to the
        # constructor is honoured instead of always hitting the class default
        response = self._sync.get(self.api_url)
        response.raise_for_status()
        meta = response.json()
        return parse_cameras(meta)

    def get_traffic_images(
        self, cameras: list[Camera], image_dir: Path
    ) -> list[TrafficImage]:
        """Save Traffic Camera images from given Cameras into image_dir.
        Creates image_dir if it does not already exist.

        Each image is written to "<image_dir>/<camera.id>.jpg".

        Args:
            cameras:
                List of traffic cameras to retrieve traffic images from.
            image_dir:
                Path to the image directory to write retrieved images.
        Returns:
            List of retrieved Traffic Images, in the same order as cameras.
        Raises:
            httpx.HTTPStatusError: If any image request responds with an error status.
        """
        # ensure image directory exists
        image_dir.mkdir(parents=True, exist_ok=True)

        async def fetch(camera: Camera) -> TrafficImage:
            response = await self._async.get(camera.image_url)
            # fail loudly rather than saving an error page as a ".jpg" file
            response.raise_for_status()
            # write image bytes to image file on disk
            image_path = image_dir / f"{camera.id}.jpg"
            with open(image_path, "wb") as f:
                for chunk in response.iter_bytes():
                    f.write(chunk)

            return TrafficImage.from_camera(camera, image_path)

        async def fetch_all() -> list[TrafficImage]:
            # perform all image fetches asynchronously
            return await asyncio.gather(*[fetch(camera) for camera in cameras])

        return asyncio.run(fetch_all())
69+
70+
71+
def parse_cameras(meta: dict) -> list[Camera]:
    """Parse the traffic images API response into Camera models.

    Only the first entry of meta["items"] is read; its timestamp is used as
    the retrieval time of every camera parsed from it.

    Args:
        meta:
            Decoded JSON response of the traffic images API endpoint.
    Returns:
        One Camera per entry in the first item's "cameras" list.
    """
    first_item = meta["items"][0]
    retrieved_on = first_item["timestamp"]

    cameras = []
    for entry in first_item["cameras"]:
        # look fields up in the same order as they are passed to Camera
        camera_id = entry["camera_id"]
        captured_on = entry["timestamp"]
        image_url = entry["image"]
        location = Location(
            longitude=entry["location"]["longitude"],
            latitude=entry["location"]["latitude"],
        )
        cameras.append(
            Camera(
                id=camera_id,
                retrieved_on=retrieved_on,
                captured_on=captured_on,
                image_url=image_url,
                location=location,
            )
        )
    return cameras

pipeline/conftest.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#
2+
# Flowmotion
3+
# Test Fixtures
4+
#
5+
6+
from datetime import datetime
7+
8+
import pytest
9+
10+
from data import Camera, Location, TrafficImage
11+
from model import Model
12+
13+
14+
@pytest.fixture
def camera() -> Camera:
    """Sample Camera with fixed id, image URL, timestamps and location."""
    captured_on = datetime(2024, 9, 27, 8, 30, 0)
    # retrieval happens one minute after capture in this sample
    retrieved_on = datetime(2024, 9, 27, 8, 31, 0)
    location = Location(longitude=103.851959, latitude=1.290270)
    return Camera(
        id="1001",
        image_url="https://images.data.gov.sg/api/traffic/1001.jpg",
        captured_on=captured_on,
        retrieved_on=retrieved_on,
        location=location,
    )
23+
24+
25+
# Fixture for TrafficImage instance
26+
@pytest.fixture
def traffic_image():
    """Sample TrafficImage that is already marked as processed."""
    processed_on = datetime(2024, 9, 27, 8, 30, 0)
    sample = TrafficImage(
        image="some_image_url",
        camera_id="camera_123",
        longitude=103.851959,
        latitude=1.290270,
        processed=True,
        congestion_rating=0.5,
        processed_on=processed_on,
        model_id=Model.MODEL_ID,
    )
    return sample

pipeline/data.py

Lines changed: 112 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,16 @@
55
#
66

77
import json
8+
import math
89
from datetime import datetime
10+
from pathlib import Path
11+
from types import NotImplementedType
12+
from typing import Optional, cast
913

1014
from pydantic import BaseModel
1115

16+
from timetools import datetime_sgt
17+
1218

1319
class Location(BaseModel):
1420
"""Geolocation consisting of longitude and latitude."""
@@ -17,14 +23,6 @@ class Location(BaseModel):
1723
latitude: float
1824

1925

20-
class Rating(BaseModel):
21-
"""Traffic Congestion rating performed by a model"""
22-
23-
rated_on: datetime
24-
model_id: str
25-
value: float
26-
27-
2826
class Camera(BaseModel):
2927
"""Traffic Camera capturing traffic images."""
3028

@@ -35,13 +33,119 @@ class Camera(BaseModel):
3533
location: Location
3634

3735

36+
class TrafficImage:
    """Traffic camera image together with its congestion rating state.

    Attributes:
        image: Reference to the traffic camera image. from_camera() stores
            the path of the image file here, as a string.
        processed: Whether this TrafficImage instance has been processed.
        congestion_rating: Congestion rating, None until processed
            (0-1 range per the original documentation).
        camera_id: ID of the camera that captured this image.
        longitude: Longitude of the camera that captured this image.
        latitude: Latitude of the camera that captured this image.
        processed_on: Timestamp when this TrafficImage instance was processed.
        model_id: Unique ID identifying the version of the model that
            performed the rating.
    """

    def __init__(
        self,
        image,
        camera_id,
        longitude,
        latitude,
        processed=False,
        congestion_rating=None,
        processed_on: Optional[datetime] = None,
        model_id: Optional[str] = None,
    ):
        # image reference and rating state
        self.image = image
        self.processed = processed
        self.congestion_rating = congestion_rating
        # originating camera
        self.camera_id = camera_id
        self.longitude = longitude
        self.latitude = latitude
        # rating provenance, filled in by set_processed()
        self.processed_on = processed_on
        self.model_id = model_id

    @classmethod
    def from_camera(cls, camera: Camera, image_path: Path) -> "TrafficImage":
        """Build an unprocessed TrafficImage for a camera's saved image.

        Args:
            camera:
                Camera model supplying the camera ID and location.
            image_path:
                Path to the retrieved captured image from the Camera.
        Returns:
            TrafficImage with rating fields left at their defaults.
        """
        image = str(image_path)
        camera_id = camera.id
        location = camera.location
        return cls(
            image=image,
            camera_id=camera_id,
            longitude=location.longitude,
            latitude=location.latitude,
        )

    def set_processed(self, congestion_rating: float, model_id: str) -> None:
        """Record a congestion rating and mark this image as processed.

        Args:
            congestion_rating:
                Rating assigned to this image.
            model_id:
                ID of the model version that produced the rating.
        """
        self.processed = True
        self.congestion_rating = congestion_rating
        # timestamp is taken at the moment the rating is recorded
        self.processed_on = datetime_sgt()
        self.model_id = model_id
94+
95+
96+
class Rating(BaseModel):
    """Traffic Congestion rating performed by a model.

    Attributes:
        rated_on: Timestamp when the rating was performed.
        model_id: ID of the model version that performed the rating.
        value: Congestion rating value.
    """

    rated_on: datetime
    model_id: str
    value: float

    @classmethod
    def from_traffic_image(cls, image: TrafficImage) -> "Rating":
        """Create a Rating from a processed TrafficImage.

        Args:
            image:
                TrafficImage whose processed_on, congestion_rating and
                model_id have all been set.
        Returns:
            Rating carrying the image's timestamp, rating value and model ID.
        Raises:
            ValueError: If any of the three required fields is None.
        """
        if (
            image.processed_on is None
            or image.congestion_rating is None
            or image.model_id is None
        ):
            raise ValueError(
                "Invalid TrafficImage: Either 'processed_on' or 'congestion_rating' or 'model_id' is None."
            )

        return cls(
            rated_on=image.processed_on,
            value=image.congestion_rating,
            model_id=image.model_id,
        )

    def equal(self, other: object) -> bool | NotImplementedType:
        """Compare with another Rating, tolerating float rounding in value.

        Args:
            other:
                Object to compare against.
        Returns:
            NotImplemented if other is not a Rating; otherwise whether
            model_id and rated_on match exactly and value matches under
            math.isclose().
        """
        if not isinstance(other, Rating):
            return NotImplemented
        return (
            self.model_id == other.model_id
            and math.isclose(self.value, other.value)
            and self.rated_on == other.rated_on
        )
130+
131+
38132
class Congestion(BaseModel):
    """Traffic congestion record: a camera, its rating and an update time."""

    camera: Camera
    rating: Rating
    updated_on: datetime

    @classmethod
    def from_traffic_image(
        cls, image: TrafficImage, camera: Camera, updated_on: datetime
    ) -> "Congestion":
        """Create a Congestion record from a processed TrafficImage.

        Args:
            image:
                TrafficImage the rating is derived from.
            camera:
                Camera to store on the record.
            updated_on:
                Timestamp to store on the record.
        Returns:
            Congestion combining the camera, derived rating and timestamp.
        """
        # Rating.from_traffic_image() validates that the image was processed
        rating = Rating.from_traffic_image(image)
        return cls(camera=camera, rating=rating, updated_on=updated_on)
148+
45149

46150
def to_json_dict(model: BaseModel):
47151
"""Convert given pydantic model into its JSON dict representation"""

pipeline/db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(self) -> None:
2323
Uses Google Application Default credentials to authenticate DB requests.
2424
See https://firebase.google.com/docs/admin/setup#initialize-sdk.
2525
"""
26-
app = firebase_admin.initialize_app()
26+
app = firebase_admin.initialize_app(options={"projectId": "flowmotion-4e268"})
2727
self._db = firestore.client(app)
2828

2929
def insert(self, table: str, data: BaseModel) -> str:

pipeline/model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55
from google.cloud import storage
66
from ultralytics import YOLO
77

8-
from TrafficImage import TrafficImage
8+
from data import TrafficImage
99

1010

1111
class Model:
12+
# uniquely identifies versions of the model.
13+
# NOTE: update this if the model changed in any way.
14+
MODEL_ID = "yolo_detect_segment_v1"
15+
1216
def __init__(
1317
self,
1418
project_id: str = "flowmotion-4e268",
@@ -58,7 +62,7 @@ def predict(self, images: list[TrafficImage]):
5862
car_results = self.car_model(img)
5963
if len(car_results) == 0 or len(car_results[0].boxes) == 0:
6064
print("No cars detected")
61-
image.set_processed(0.0)
65+
image.set_processed(0.0, Model.MODEL_ID)
6266
continue
6367

6468
# Access car bounding boxes (x, y, w, h)
@@ -95,4 +99,4 @@ def predict(self, images: list[TrafficImage]):
9599
print(f"Congestion rating: {congestion_rating}")
96100

97101
# Set the predicted congestion rating
98-
image.set_processed(congestion_rating)
102+
image.set_processed(congestion_rating, Model.MODEL_ID)

0 commit comments

Comments
 (0)