diff --git a/controller/notification/notification_data.py b/controller/notification/notification_data.py index 948ee775..e040d2c3 100644 --- a/controller/notification/notification_data.py +++ b/controller/notification/notification_data.py @@ -30,6 +30,13 @@ "page": enums.Pages.SETTINGS.value, "docs": enums.DOCS.UPLOADING_DATA.value, }, + enums.NotificationType.IMPORT_CONVERSION_ERROR.value: { + "message_template": "Data type couldn't be forced (@@arg@@).", + "title": "Data import", + "level": enums.Notification.ERROR.value, + "page": enums.Pages.SETTINGS.value, + "docs": enums.DOCS.UPLOADING_DATA.value, + }, enums.NotificationType.INVALID_FILE_TYPE.value: { "message_template": "File type @@arg@@ is currently not supported.", "title": "Data import", diff --git a/controller/transfer/record_transfer_manager.py b/controller/transfer/record_transfer_manager.py index a1a00d7e..6eac65f4 100644 --- a/controller/transfer/record_transfer_manager.py +++ b/controller/transfer/record_transfer_manager.py @@ -1,6 +1,6 @@ import json import logging -from typing import Dict, Any, Optional, Tuple, List +from typing import Dict, Any, Optional, Tuple, List, Callable import pandas as pd @@ -13,24 +13,24 @@ labeling_task_label, labeling_task, organization, - project, record, record_label_association, upload_task, ) -from controller.user import manager as user_manager + from controller.upload_task import manager as upload_task_manager from controller.tokenization import manager as token_manager from util import file, security from submodules.s3 import controller as s3 -from submodules.model import enums, events, UploadTask, Attribute +from submodules.model import enums, UploadTask, Attribute from util import category from util import notification from controller.transfer.util import convert_to_record_dict +import os + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -import os def import_records_and_rlas( @@ -43,6 +43,7 @@ def import_records_and_rlas( CHUNK_SIZE = 500 
chunks = [data[x : x + CHUNK_SIZE] for x in range(0, len(data), CHUNK_SIZE)] chunks_count = len(chunks) + attribute_lookup = None for idx, chunk in enumerate(chunks): if upload_task is not None: logger.debug( @@ -61,7 +62,32 @@ def import_records_and_rlas( if idx == 0: create_attributes_and_get_text_attributes(project_id, records_data) primary_keys = attribute.get_primary_keys(project_id) + if attribute_lookup is None: + existing_attributes = attribute.get_all(project_id) + attribute_lookup = { + attribute.name: __match_data_type_to_function(attribute.data_type) + for attribute in existing_attributes + } + attribute_lookup = { + k: v for k, v in attribute_lookup.items() if v is not None + } + try: + __force_data_type_for_attributes(records_data, attribute_lookup) + except Exception as e: + logger.error(f"Error while forcing data type for attributes: {e}") + if upload_task is not None: + + upload_task_manager.update_task( + project_id, upload_task.id, state=enums.UploadStates.ERROR.value + ) + notification.create_notification( + enums.NotificationType.IMPORT_CONVERSION_ERROR, + user_id, + project_id, + str(e), + ) + raise e import_labeling_tasks_and_labels_pipeline( project_id=project_id, tasks_data=tasks_data ) @@ -81,6 +107,38 @@ def import_records_and_rlas( ) +def __force_data_type_for_attributes( + records_data: List[Dict[str, Any]], + attribute_lookup: Dict[str, Callable], +) -> None: + if len(records_data) == 0: + return + if len(attribute_lookup) == 0: + return + for record_data in records_data: + for key, value in record_data.items(): + if ( + key in attribute_lookup + and value is not None + and not isinstance(value, attribute_lookup[key]) + ): + record_data[key] = attribute_lookup[key](value) + + +def __match_data_type_to_function(data_type: str) -> Callable: + if data_type == enums.DataTypes.INTEGER.value: + return int + elif data_type == enums.DataTypes.FLOAT.value: + return float + elif data_type == enums.DataTypes.BOOLEAN.value: + return bool + elif ( 
+ data_type == enums.DataTypes.TEXT.value + or data_type == enums.DataTypes.CATEGORY.value + ): + return str + + def download_file(project_id: str, task: UploadTask) -> str: # TODO is copied from import_file and can be refactored because atm its duplicated code upload_task_manager.update_task( @@ -129,7 +187,6 @@ def import_file(project_id: str, upload_task: UploadTask) -> None: project_id, column_mappings, ) - number_records = len(data) import_records_and_rlas( project_id, upload_task.user_id, data, upload_task, record_category ) @@ -141,8 +198,6 @@ def import_file(project_id: str, upload_task: UploadTask) -> None: general.commit() upload_task_manager.update_upload_task_to_finished(upload_task) - user = user_manager.get_or_create_user(upload_task.user_id) - project_item = project.get(project_id) general.commit() @@ -321,6 +376,6 @@ def create_attributes_and_get_text_attributes( text_attributes.append(attribute_item) general.flush() if created_something: - notification.send_organization_update(project_id, f"attributes_updated") + notification.send_organization_update(project_id, "attributes_updated") return text_attributes diff --git a/submodules/model b/submodules/model index 29ec9451..15df3b66 160000 --- a/submodules/model +++ b/submodules/model @@ -1 +1 @@ -Subproject commit 29ec9451308bdffd3b4aacb44ef4e952ac64d49f +Subproject commit 15df3b661f51d892f19f52af874fe5db944c5a4c