diff --git a/pipeline/TrafficImage.py b/pipeline/TrafficImage.py deleted file mode 100644 index c904edc..0000000 --- a/pipeline/TrafficImage.py +++ /dev/null @@ -1,31 +0,0 @@ -class TrafficImage: - """Traffic Image to be rated for congestion. - - Attributes: - image: URL that retrieves Traffic camera image. - processed: Whether this TrafficImage instance has been processed. - congestion_rating: 0-1 Congestion Rating - camera_id: ID of the camera that captured this image - longitude: Longitude of the camera that captured this image - latitude: Latitude of the camera that captured this image - """ - - def __init__( - self, - image, - camera_id, - longitude, - latitude, - processed=False, - congestion_rating=None, - ): - self.image = image - self.processed = processed - self.congestion_rating = congestion_rating - self.camera_id = camera_id - self.longitude = longitude - self.latitude = latitude - - def set_processed(self, congestion_rating): - self.processed = True - self.congestion_rating = congestion_rating diff --git a/pipeline/api.py b/pipeline/api.py index 8f6ed67..adf2fc8 100644 --- a/pipeline/api.py +++ b/pipeline/api.py @@ -1,51 +1,86 @@ -import requests - - -class APIClient: - def __init__(self, url="https://api.data.gov.sg/v1/transport/traffic-images"): - self.url = url - self.timestamp = None - self.api_status = "Unverified" - self.camera_id_array = [] - - # Get API response - response = requests.get(self.url) - response_json = response.json() - self.metadata = response_json - - # Get and set API status - self.api_status = self.metadata["api_info"]["status"] - - # Get and set timestamp - self.timestamp = self.metadata["items"][0]["timestamp"] - - print(f"The API status is: {self.api_status}") - print(f"The API was called at: {self.timestamp}") - - for item in self.metadata["items"]: - for camera in item["cameras"]: - self.camera_id_array.append(camera["camera_id"]) - - def extract_image(self, camera_id): - # Loop through the items and cameras to find the correct camera_id - for item in self.metadata["items"]: - for camera in item["cameras"]: - if camera["camera_id"] == str(camera_id): - return camera[ - "image" - ] # Return the image URL if the camera ID matches - # If camera ID is not found - raise RuntimeError(f"Camera ID {camera_id} not found.") - - def extract_latlon(self, camera_id): - for item in self.metadata["items"]: - for camera in item["cameras"]: - if camera["camera_id"] == str(camera_id): - longitude = camera["location"]["longitude"] - latitude = camera["location"]["latitude"] - return ( - longitude, - latitude, - ) # Return both longitude and latitude as a tuple - # If camera ID is not found - raise RuntimeError(f"Camera ID {camera_id} not found.") +# +# Flowmotion +# Pipeline +# Traffic Images API Client +# + +import asyncio +from pathlib import Path + +import httpx + +from data import Camera, Location, TrafficImage + + +class TrafficImageAPI: + """Data.gov.sg Traffic Images API Client.""" + + API_URL = "https://api.data.gov.sg/v1/transport/traffic-images" + + def __init__(self, api_url: str = API_URL): + self.api_url = api_url + self._sync = httpx.Client() + self._async = httpx.AsyncClient() + + def get_cameras(self) -> list[Camera]: + """Get Traffic Camera metadata from traffic images API endpoint. + + Returns: + Parsed traffic camera metadata. + """ + # fetch traffic-images api endpoint + response = self._sync.get(self.API_URL) + response.raise_for_status() + meta = response.json() + return parse_cameras(meta) + + def get_traffic_images( + self, cameras: list[Camera], image_dir: Path + ) -> list[TrafficImage]: + """Save Traffic Camera images from given Cameras into image_dir. + Creates image_dir if it does not already exist. + + Args: + cameras: + List of traffic cameras to retrieve traffic images from. + image_dir: + Path the image directory to write retrieved images. + Returns: + List of retrieve Traffic Images. + """ + # ensure image directory exists + image_dir.mkdir(parents=True, exist_ok=True) + + async def fetch(camera: Camera) -> TrafficImage: + response = await self._async.get(camera.image_url) + # write image bytes to image file on disk + image_path = image_dir / f"{camera.id}.jpg" + with open(image_path, "wb") as f: + for chunk in response.iter_bytes(): + f.write(chunk) + + return TrafficImage.from_camera(camera, image_path) + + async def fetch_all() -> list[TrafficImage]: + # perform all image fetches asynchronously + return await asyncio.gather(*[fetch(camera) for camera in cameras]) + + return asyncio.run(fetch_all()) + + +def parse_cameras(meta: dict) -> list[Camera]: + meta = meta["items"][0] + retrieved_on = meta["timestamp"] + return [ + Camera( + id=c["camera_id"], + retrieved_on=retrieved_on, + captured_on=c["timestamp"], + image_url=c["image"], + location=Location( + longitude=c["location"]["longitude"], + latitude=c["location"]["latitude"], + ), + ) + for c in meta["cameras"] + ] diff --git a/pipeline/conftest.py b/pipeline/conftest.py new file mode 100644 index 0000000..629923f --- /dev/null +++ b/pipeline/conftest.py @@ -0,0 +1,37 @@ +# +# Flowmotion +# Test Fixtures +# + +from datetime import datetime + +import pytest + +from data import Camera, Location, TrafficImage +from model import Model + + +@pytest.fixture +def camera() -> Camera: + return Camera( + id="1001", + image_url="https://images.data.gov.sg/api/traffic/1001.jpg", + captured_on=datetime(2024, 9, 27, 8, 30, 0), + retrieved_on=datetime(2024, 9, 27, 8, 31, 0), + location=Location(longitude=103.851959, latitude=1.290270), + ) + + +# Fixture for TrafficImage instance +@pytest.fixture +def traffic_image(): + return TrafficImage( + image="some_image_url", + processed=True, + congestion_rating=0.5, + camera_id="camera_123", + longitude=103.851959, + latitude=1.290270, + processed_on=datetime(2024, 9, 27, 8, 30, 0), + model_id=Model.MODEL_ID, + ) diff --git a/pipeline/data.py b/pipeline/data.py index 01b954a..35d0ec1 100644 --- a/pipeline/data.py +++ b/pipeline/data.py @@ -5,10 +5,16 @@ # import json +import math from datetime import datetime +from pathlib import Path +from types import NotImplementedType +from typing import Optional, cast from pydantic import BaseModel +from timetools import datetime_sgt + class Location(BaseModel): """Geolocation consisting of longitude and latitude.""" @@ -17,14 +23,6 @@ class Location(BaseModel): latitude: float -class Rating(BaseModel): - """Traffic Congestion rating performed by a model""" - - rated_on: datetime - model_id: str - value: float - - class Camera(BaseModel): """Traffic Camera capturing traffic images.""" @@ -35,6 +33,102 @@ class Camera(BaseModel): location: Location +class TrafficImage: + """Traffic Image to be rated for congestion. + + Attributes: + image: URL that retrieves Traffic camera image. + processed: Whether this TrafficImage instance has been processed. + congestion_rating: 0-1 Congestion Rating + camera_id: ID of the camera that captured this image + longitude: Longitude of the camera that captured this image + latitude: Latitude of the camera that captured this image + processed_on: Timestamp when this TrafficImage instance has been processed. + model_id: Unique ID to identify the version of the model the performed the rating. + """ + + def __init__( + self, + image, + camera_id, + longitude, + latitude, + processed=False, + congestion_rating=None, + processed_on: Optional[datetime] = None, + model_id: Optional[str] = None, + ): + self.image = image + self.processed = processed + self.congestion_rating = congestion_rating + self.camera_id = camera_id + self.longitude = longitude + self.latitude = latitude + self.processed_on = processed_on + self.model_id = model_id + + @classmethod + def from_camera(cls, camera: Camera, image_path: Path) -> "TrafficImage": + """Create TrafficImage from camera & its image path. + + Args: + camera: + Camera model to create TrafficImage from. + image_path: + Path to the retrieved captured image from the Camera. + Returns: + Constructed TrafficImage + """ + return cls( + image=str(image_path), + camera_id=camera.id, + longitude=camera.location.longitude, + latitude=camera.location.latitude, + ) + + def set_processed(self, congestion_rating: float, model_id: str): + self.processed = True + self.congestion_rating = congestion_rating + self.processed_on = datetime_sgt() + self.model_id = model_id + + +class Rating(BaseModel): + """Traffic Congestion rating performed by a model""" + + rated_on: datetime + model_id: str + value: float + + @classmethod + def from_traffic_image(cls, image: TrafficImage) -> "Rating": + if ( + image.processed_on is None + or image.congestion_rating is None + or image.model_id is None + ): + __import__("pprint").pprint(image.__dict__) + raise ValueError( + "Invalid TrafficImage: Either 'processed_on' or 'congestion_rating' or 'model_id' is None." + ) + + return cls( + rated_on=image.processed_on, + value=image.congestion_rating, + model_id=image.model_id, + ) + + def equal(self, other: object) -> bool | NotImplementedType: + if not isinstance(other, Rating): + return NotImplemented + other = cast(Rating, other) + return ( + self.model_id == other.model_id + and math.isclose(self.value, other.value) + and self.rated_on == other.rated_on + ) + + class Congestion(BaseModel): """Traffic Congestion data.""" @@ -42,6 +136,16 @@ class Congestion(BaseModel): rating: Rating updated_on: datetime + @classmethod + def from_traffic_image( + cls, image: TrafficImage, camera: Camera, updated_on: datetime + ) -> "Congestion": + return cls( + camera=camera, + rating=Rating.from_traffic_image(image), + updated_on=updated_on, + ) + def to_json_dict(model: BaseModel): """Convert given pydantic model into the its JSON dict representation""" diff --git a/pipeline/db.py b/pipeline/db.py index 8f46588..698a43a 100644 --- a/pipeline/db.py +++ b/pipeline/db.py @@ -23,7 +23,7 @@ def __init__(self) -> None: Uses Google Application Default credentials with authenticate DB requests. See https://firebase.google.com/docs/admin/setup#initialize-sdk. """ - app = firebase_admin.initialize_app() + app = firebase_admin.initialize_app(options={"projectId": "flowmotion-4e268"}) self._db = firestore.client(app) def insert(self, table: str, data: BaseModel) -> str: diff --git a/pipeline/model.py b/pipeline/model.py index e7ee3da..49d06de 100644 --- a/pipeline/model.py +++ b/pipeline/model.py @@ -5,10 +5,14 @@ from google.cloud import storage from ultralytics import YOLO -from TrafficImage import TrafficImage +from data import TrafficImage class Model: + # uniquely identifies versions of the model. + # NOTE: update this if the model changed in any way. + MODEL_ID = "yolo_detect_segment_v1" + def __init__( self, project_id: str = "flowmotion-4e268", @@ -58,7 +62,7 @@ def predict(self, images: list[TrafficImage]): car_results = self.car_model(img) if len(car_results) == 0 or len(car_results[0].boxes) == 0: print("No cars detected") - image.set_processed(0.0) + image.set_processed(0.0, Model.MODEL_ID) continue # Access car bounding boxes (x, y, w, h) @@ -95,4 +99,4 @@ def predict(self, images: list[TrafficImage]): print(f"Congestion rating: {congestion_rating}") # Set the predicted congestion rating - image.set_processed(congestion_rating) + image.set_processed(congestion_rating, Model.MODEL_ID) diff --git a/pipeline/pipeline.py b/pipeline/pipeline.py index 90a26f8..c14f284 100644 --- a/pipeline/pipeline.py +++ b/pipeline/pipeline.py @@ -1,29 +1,43 @@ -from api import APIClient +# +# Flowmotion +# ML Pipeline +# Entrypoint +# + + +from pathlib import Path +from tempfile import mkdtemp + +from api import TrafficImageAPI +from data import Congestion +from db import DatabaseClient from model import Model from rating_validator import RatingValidator -from TrafficImage import TrafficImage +from timetools import datetime_sgt if __name__ == "__main__": - apiclient = APIClient("https://api.data.gov.sg/v1/transport/traffic-images") - active_cameraIDs = apiclient.camera_id_array - current_traffic_camera_objects = [] # array of TrafficImage objects - - # populating array of TrafficImage objects - for camera_id in active_cameraIDs: - image_url = apiclient.extract_image(camera_id) - longitude, latitude = apiclient.extract_latlon(camera_id) - traffic_camera_obj = TrafficImage( - camera_id=camera_id, image=image_url, longitude=longitude, latitude=latitude - ) - current_traffic_camera_objects.append(traffic_camera_obj) + # fetch camera metadata & images from traffic image api + api = TrafficImageAPI() + cameras = api.get_cameras() + image_dir = Path(mkdtemp()) + traffic_images = api.get_traffic_images(cameras, image_dir) - # run model + # perform inference on model for traffic congestion rating active_model = Model() - active_model.predict(current_traffic_camera_objects) - + active_model.predict(traffic_images) # validate model output - for object in current_traffic_camera_objects: - validator = RatingValidator(object) - validator.validate() + for traffic_image in traffic_images: + RatingValidator(traffic_image).validate() + + # construct congestion with model's traffic congestion rating + congestions = [] # type: list[Congestion] + updated_on = datetime_sgt() + for camera, traffic_image in zip(cameras, traffic_images): + congestions.append( + Congestion.from_traffic_image(traffic_image, camera, updated_on) + ) - # iterate through current_traffic_camera_objects and populate json for firebase storage + # write congestions to the database + db = DatabaseClient() + for congestion in congestions: + db.insert("congestions", congestion) diff --git a/pipeline/pyproject.toml b/pipeline/pyproject.toml index 4304b8a..67528f0 100644 --- a/pipeline/pyproject.toml +++ b/pipeline/pyproject.toml @@ -6,3 +6,6 @@ [tool.isort] profile = "black" + +[tool.pytest.ini_options] +asyncio_default_fixture_loop_scope = "loop" diff --git a/pipeline/rating_validator.py b/pipeline/rating_validator.py index d9d0169..900072c 100644 --- a/pipeline/rating_validator.py +++ b/pipeline/rating_validator.py @@ -1,4 +1,4 @@ -from TrafficImage import TrafficImage +from data import TrafficImage class RatingValidator: @@ -14,7 +14,7 @@ def __init__(self, TrafficImageObject: TrafficImage = None): def validate(self) -> None: if self.traffic_image_object is None: - raise Exception("No Traffic Image Parsed Through Validatior!!!") + raise Exception("No Traffic Image Parsed Through Validator!!!") elif self.congestion_rating < 0 or self.congestion_rating > 1: raise Exception("Congestion Rating Is Invalid (< 0 or > 1)!!!") elif self.processed is False: diff --git a/pipeline/test_api.py b/pipeline/test_api.py index b670807..9a4c4f3 100644 --- a/pipeline/test_api.py +++ b/pipeline/test_api.py @@ -5,64 +5,42 @@ # import json +from io import SEEK_END from pathlib import Path +from shutil import rmtree +from tempfile import mkdtemp import pytest -from api import APIClient -from TrafficImage import TrafficImage +from api import TrafficImageAPI, parse_cameras +from data import Camera @pytest.fixture -def api() -> APIClient: - return APIClient(url="https://api.data.gov.sg/v1/transport/traffic-images") +def api() -> TrafficImageAPI: + return TrafficImageAPI() @pytest.fixture -def sample_camera_id() -> str: - # Load camera ID from a sample JSON file for testing +def cameras() -> list[Camera]: with open(Path(__file__).parent / "resources" / "traffic_images.json") as f: - cameras_data = json.loads(f.read()) - return cameras_data["items"][0]["cameras"][0][ - "camera_id" - ] # Grab a camera_id from the file + return parse_cameras(json.loads(f.read())) -def test_get_camera_ids(api: APIClient): - assert len(api.camera_id_array) > 0 # Ensure the APIClient fetched camera IDs +def test_get_cameras(api: TrafficImageAPI): + assert len(api.get_cameras()) > 0 -def test_extract_image(api: APIClient, sample_camera_id: str): - image_url = api.extract_image(sample_camera_id) - assert isinstance(image_url, str) - assert image_url.startswith("http") # Ensure a valid image URL is returned +def test_get_traffic_images(api: TrafficImageAPI, cameras: list[Camera]): + image_dir = mkdtemp() + traffic_images = api.get_traffic_images(cameras, Path(image_dir)) + assert len(traffic_images) > 0 + # check image paths are nonempty + for traffic_image in traffic_images: + with open(traffic_image.image) as f: + f.seek(0, SEEK_END) + assert f.tell() > 0 -def test_extract_latlon(api: APIClient, sample_camera_id: str): - latlon = api.extract_latlon(sample_camera_id) - assert isinstance(latlon, tuple) - assert len(latlon) == 2 # Ensure we get a tuple with (longitude, latitude) - longitude, latitude = latlon - assert isinstance(longitude, float) - assert isinstance(latitude, float) - - -def test_traffic_image_class(): - # Example test for TrafficImage class - traffic_image = TrafficImage( - image="some_image_url", - processed=False, - congestion_rating=None, - camera_id="camera_123", - longitude=103.851959, - latitude=1.290270, - ) - - assert traffic_image.image == "some_image_url" - assert traffic_image.camera_id == "camera_123" - assert not traffic_image.processed # Ensure it's not processed initially - - # Simulate processing the image - traffic_image.set_processed(0.5) - assert traffic_image.processed # Now it should be processed - assert traffic_image.congestion_rating == 0.5 + # clean up image files + rmtree(image_dir) diff --git a/pipeline/test_data.py b/pipeline/test_data.py index f6f2f85..2d4ff1c 100644 --- a/pipeline/test_data.py +++ b/pipeline/test_data.py @@ -9,29 +9,66 @@ from datetime import datetime from pathlib import Path +import pytest from jsonschema import validate -from data import Camera, Congestion, Location, Rating, to_json_dict +from data import Camera, Congestion, Rating, TrafficImage, to_json_dict +from model import Model CONGESTION_SCHEMA = Path(__file__).parent.parent / "schema" / "congestion.schema.json" -def test_congestion_json(): - congestion = Congestion( - camera=Camera( - id="1001", - image_url="https://images.data.gov.sg/api/traffic/1001.jpg", - captured_on=datetime(2024, 9, 27, 8, 30, 0), - retrieved_on=datetime(2024, 9, 27, 8, 31, 0), - location=Location(longitude=103.851959, latitude=1.290270), - ), +@pytest.fixture +def congestion(camera: Camera) -> Congestion: + return Congestion( + camera=camera, rating=Rating( - rated_on=datetime(2024, 9, 27, 8, 32, 0), model_id="v1.0", value=0.75 + rated_on=datetime(2024, 9, 27, 8, 32, 0), + model_id=Model.MODEL_ID, + value=0.75, ), updated_on=datetime(2024, 9, 27, 8, 33, 0), ) + +def test_congestion_json(congestion: Congestion): with open(CONGESTION_SCHEMA, "r") as f: schema = json.load(f) validate(to_json_dict(congestion), schema) + + +def test_congestion_from_traffic_image(camera: Camera): + traffic_image = TrafficImage.from_camera(camera, Path("image.png")) + traffic_image.set_processed(0.75, Model.MODEL_ID) + updated_on = datetime(2024, 9, 27, 8, 33, 0) + + actual = Congestion.from_traffic_image(traffic_image, camera, updated_on) + assert actual.camera == camera + assert actual.rating.equal(Rating.from_traffic_image(traffic_image)) + assert actual.updated_on == updated_on + + +def test_rating_from_traffic_image(traffic_image: TrafficImage): + assert Rating.from_traffic_image(traffic_image).equal( + Rating( + model_id=Model.MODEL_ID, + value=traffic_image.congestion_rating, # type: ignore + rated_on=traffic_image.processed_on, # type: ignore + ) + ) + + +def test_trafficimage_from_camera(camera: Camera): + image_path = Path("image.jpg") + traffic_image = TrafficImage.from_camera(camera, image_path) + assert traffic_image.__dict__ == { + "image": str(image_path), + "camera_id": camera.id, + "longitude": camera.location.longitude, + "latitude": camera.location.latitude, + "processed": False, + "processed_on": None, + "congestion_rating": None, + "model_id": None, + } diff --git a/pipeline/test_db.py b/pipeline/test_db.py index 6357fe3..fa02d00 100644 --- a/pipeline/test_db.py +++ b/pipeline/test_db.py @@ -10,7 +10,6 @@ # with firestore eg. by setting GOOGLE_APPLICATION_CREDENTIALS env var. -import os from uuid import uuid4 import pytest @@ -26,7 +25,6 @@ class Model(BaseModel): @pytest.fixture(scope="session") def db() -> DatabaseClient: - os.environ["GOOGLE_CLOUD_PROJECT"] = "flowmotion-4e268" return DatabaseClient() diff --git a/pipeline/test_model.py b/pipeline/test_model.py index 1ec5d81..1aa2d8a 100644 --- a/pipeline/test_model.py +++ b/pipeline/test_model.py @@ -9,72 +9,29 @@ # with firestore eg. by setting GOOGLE_APPLICATION_CREDENTIALS env var. -import os +from pathlib import Path from shutil import rmtree from tempfile import mkdtemp from typing import Any, Generator import pytest -import requests +from api import TrafficImageAPI +from data import TrafficImage from model import Model -from TrafficImage import TrafficImage @pytest.fixture(scope="session") def traffic_images() -> Generator[list[TrafficImage], Any, Any]: - """Fixture downloads traffic camera images from the Traffic Images API into temporary directory. - Yields Traffic Image instances for each traffic image retrieved. - """ # create temporary directory for traffic images - save_dir = mkdtemp() + api = TrafficImageAPI() + image_dir = Path(mkdtemp()) + cameras = api.get_cameras() - # API URL to fetch images - api_url = "https://api.data.gov.sg/v1/transport/traffic-images" - response = requests.get(api_url) - if response.status_code == 200: - data = response.json() + yield api.get_traffic_images(cameras, image_dir) - camera_ids = [] - image_urls = [] - locations = [] - - for item in data["items"]: - for camera in item["cameras"]: - camera_ids.append(camera["camera_id"]) - image_urls.append(camera["image"]) - locations.append( - (camera["location"]["latitude"], camera["location"]["longitude"]) - ) - file_paths = [] - for idx, (camera_id, url, location) in enumerate( - zip(camera_ids, image_urls, locations) - ): - response = requests.get(url, stream=True) - if response.status_code == 200: - file_path = os.path.join(save_dir, f"image_{camera_id}.jpg") - with open(file_path, "wb") as file: - for chunk in response.iter_content(1024): - file.write(chunk) - file_paths.append((camera_id, file_path, location)) - - # Yield TrafficImage objects for each image retrieved - yield [ - TrafficImage( - image=path, - camera_id=camera_id, - latitude=location[0], - longitude=location[1], - ) - for camera_id, path, location in file_paths - ] - - # clean up traffic images - rmtree(save_dir) - else: - raise Exception( - f"Failed to retrieve images from API. Status code: {response.status_code}" - ) + # clean up traffic images + rmtree(image_dir) @pytest.mark.integration diff --git a/pipeline/test_rating_validator.py b/pipeline/test_rating_validator.py index c960bc5..3a2f46c 100644 --- a/pipeline/test_rating_validator.py +++ b/pipeline/test_rating_validator.py @@ -1,20 +1,6 @@ import pytest from rating_validator import RatingValidator -from TrafficImage import TrafficImage - - -# Fixture for TrafficImage instance -@pytest.fixture -def traffic_image(): - return TrafficImage( - image="some_image_url", - processed=True, - congestion_rating=0.5, - camera_id="camera_123", - longitude=103.851959, - latitude=1.290270, - ) # Fixture for RatingValidator instance @@ -34,9 +20,7 @@ def test_validate_with_valid_data(validator): def test_validate_with_no_traffic_image(): # Test without a TrafficImage object validator = RatingValidator(TrafficImageObject=None) - with pytest.raises( - Exception, match="No Traffic Image Parsed Through Validatior!!!" - ): + with pytest.raises(Exception, match="No Traffic Image Parsed Through Validator!!!"): validator.validate() diff --git a/pipeline/timetools.py b/pipeline/timetools.py new file mode 100644 index 0000000..4ab7d44 --- /dev/null +++ b/pipeline/timetools.py @@ -0,0 +1,14 @@ +# +# Flowmotion +# Pipeline +# Time Utitilties +# + + +from datetime import datetime +from zoneinfo import ZoneInfo + + +def datetime_sgt() -> datetime: + """Get the current timestamp in the SGT timezone.""" + return datetime.now(ZoneInfo("Asia/Singapore"))