Skip to content

Forces the datatype of a column after an upload #283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions controller/notification/notification_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@
"page": enums.Pages.SETTINGS.value,
"docs": enums.DOCS.UPLOADING_DATA.value,
},
enums.NotificationType.IMPORT_CONVERSION_ERROR.value: {
"message_template": "Data type count't be forced (@@arg@@).",
"title": "Data import",
"level": enums.Notification.ERROR.value,
"page": enums.Pages.SETTINGS.value,
"docs": enums.DOCS.UPLOADING_DATA.value,
},
enums.NotificationType.INVALID_FILE_TYPE.value: {
"message_template": "File type @@arg@@ is currently not supported.",
"title": "Data import",
Expand Down
73 changes: 64 additions & 9 deletions controller/transfer/record_transfer_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import logging
from typing import Dict, Any, Optional, Tuple, List
from typing import Dict, Any, Optional, Tuple, List, Callable

import pandas as pd

Expand All @@ -13,24 +13,24 @@
labeling_task_label,
labeling_task,
organization,
project,
record,
record_label_association,
upload_task,
)
from controller.user import manager as user_manager

from controller.upload_task import manager as upload_task_manager
from controller.tokenization import manager as token_manager
from util import file, security
from submodules.s3 import controller as s3
from submodules.model import enums, events, UploadTask, Attribute
from submodules.model import enums, UploadTask, Attribute
from util import category
from util import notification
from controller.transfer.util import convert_to_record_dict
import os


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
import os


def import_records_and_rlas(
Expand All @@ -43,6 +43,7 @@ def import_records_and_rlas(
CHUNK_SIZE = 500
chunks = [data[x : x + CHUNK_SIZE] for x in range(0, len(data), CHUNK_SIZE)]
chunks_count = len(chunks)
attribute_lookup = None
for idx, chunk in enumerate(chunks):
if upload_task is not None:
logger.debug(
Expand All @@ -61,7 +62,32 @@ def import_records_and_rlas(
if idx == 0:
create_attributes_and_get_text_attributes(project_id, records_data)
primary_keys = attribute.get_primary_keys(project_id)
if attribute_lookup is None:
existing_attributes = attribute.get_all(project_id)
attribute_lookup = {
attribute.name: __match_data_type_to_function(attribute.data_type)
for attribute in existing_attributes
}
attribute_lookup = {
k: v for k, v in attribute_lookup.items() if v is not None
}
try:
__force_data_type_for_attributes(records_data, attribute_lookup)
except Exception as e:
logger.error(f"Error while forcing data type for attributes: {e}")
if upload_task is not None:

upload_task_manager.update_task(
project_id, upload_task.id, state=enums.UploadStates.ERROR.value
)

notification.create_notification(
enums.NotificationType.IMPORT_CONVERSION_ERROR,
user_id,
project_id,
str(e),
)
raise e
import_labeling_tasks_and_labels_pipeline(
project_id=project_id, tasks_data=tasks_data
)
Expand All @@ -81,6 +107,38 @@ def import_records_and_rlas(
)


def __force_data_type_for_attributes(
records_data: List[Dict[str, Any]],
attribute_lookup: Dict[str, Callable],
) -> None:
if len(records_data) == 0:
return
if len(attribute_lookup) == 0:
return
for record_data in records_data:
for key, value in record_data.items():
if (
key in attribute_lookup
and value is not None
and not isinstance(value, attribute_lookup[key])
):
record_data[key] = attribute_lookup[key](value)


def __match_data_type_to_function(data_type: str) -> Callable:
if data_type == enums.DataTypes.INTEGER.value:
return int
elif data_type == enums.DataTypes.FLOAT.value:
return float
elif data_type == enums.DataTypes.BOOLEAN.value:
return bool
elif (
data_type == enums.DataTypes.TEXT.value
or data_type == enums.DataTypes.CATEGORY.value
):
return str


def download_file(project_id: str, task: UploadTask) -> str:
# TODO is copied from import_file and can be refactored because atm its duplicated code
upload_task_manager.update_task(
Expand Down Expand Up @@ -129,7 +187,6 @@ def import_file(project_id: str, upload_task: UploadTask) -> None:
project_id,
column_mappings,
)
number_records = len(data)
import_records_and_rlas(
project_id, upload_task.user_id, data, upload_task, record_category
)
Expand All @@ -141,8 +198,6 @@ def import_file(project_id: str, upload_task: UploadTask) -> None:
general.commit()
upload_task_manager.update_upload_task_to_finished(upload_task)

user = user_manager.get_or_create_user(upload_task.user_id)
project_item = project.get(project_id)
general.commit()


Expand Down Expand Up @@ -321,6 +376,6 @@ def create_attributes_and_get_text_attributes(
text_attributes.append(attribute_item)
general.flush()
if created_something:
notification.send_organization_update(project_id, f"attributes_updated")
notification.send_organization_update(project_id, "attributes_updated")

return text_attributes
2 changes: 1 addition & 1 deletion submodules/model
Loading