@@ -131,7 +131,7 @@ class ImporterPass(Enum):
131131 IMPORT_DATA = 2
132132
133133
134- class PreImportResult (Enum ):
134+ class PreImportRecordResult (Enum ):
135135 """Pre Import Response."""
136136
137137 SKIP_RECORD = False
@@ -149,7 +149,8 @@ class SourceFieldSource(Enum):
149149 IDENTIFIER = auto () # Fields used as identifiers
150150
151151
152- PreImport = Callable [[RecordData , ImporterPass ], PreImportResult ]
152+ PreImportRecord = Callable [[RecordData , ImporterPass ], PreImportRecordResult ]
153+ PostImportRecord = Callable [[RecordData , DiffSyncBaseModel ], None ]
153154SourceDataGenerator = Callable [[], Iterable [SourceRecord ]]
154155SourceFieldImporter = Callable [[RecordData , DiffSyncBaseModel ], None ]
155156GetPkFromData = Callable [[RecordData ], Uid ]
@@ -205,7 +206,8 @@ def configure_model(
205206 default_reference : Optional [RecordData ] = None ,
206207 flags : Optional [DiffSyncModelFlags ] = None ,
207208 nautobot_flags : Optional [DiffSyncModelFlags ] = None ,
208- pre_import : Optional [PreImport ] = None ,
209+ pre_import_record : Optional [PreImportRecord ] = None ,
210+ post_import_record : Optional [PostImportRecord ] = None ,
209211 disable_related_reference : Optional [bool ] = None ,
210212 forward_references : Optional [ForwardReferences ] = None ,
211213 fill_dummy_data : Optional [FillDummyData ] = None ,
@@ -257,8 +259,10 @@ def configure_model(
257259 wrapper .flags = flags
258260 if nautobot_flags is not None :
259261 wrapper .nautobot .flags = nautobot_flags
260- if pre_import :
261- wrapper .pre_import = pre_import
262+ if pre_import_record :
263+ wrapper .pre_import_record = pre_import_record
264+ if post_import_record :
265+ wrapper .post_import_record = post_import_record
262266 if disable_related_reference is not None :
263267 wrapper .disable_related_reference = disable_related_reference
264268 if forward_references :
@@ -349,7 +353,7 @@ def get_nautobot_content_type_uid(self, content_type: ContentTypeValue) -> int:
349353 def load (self ) -> None :
350354 """Load data from the source."""
351355 self .import_data ()
352- self .post_import ()
356+ self .post_load ()
353357
354358 def import_data (self ) -> None :
355359 """Import data from the source."""
@@ -379,9 +383,9 @@ def import_data(self) -> None:
379383 for content_type , data in get_source_data ():
380384 self .wrappers [content_type ].second_pass (data )
381385
382- def post_import (self ) -> None :
386+ def post_load (self ) -> None :
383387 """Post import processing."""
384- while any (wrapper .post_import () for wrapper in self .wrappers .values ()):
388+ while any (wrapper .post_process_references () for wrapper in self .wrappers .values ()):
385389 pass
386390
387391 for nautobot_wrapper in self .get_imported_nautobot_wrappers ():
@@ -458,7 +462,8 @@ def __init__(self, adapter: SourceAdapter, content_type: ContentTypeStr, nautobo
458462
459463 # Source fields defintions
460464 self .fields : OrderedDict [FieldName , SourceField ] = OrderedDict ()
461- self .pre_import : Optional [PreImport ] = None
465+ self .pre_import_record : Optional [PreImportRecord ] = None
466+ self .post_import_record : Optional [PostImportRecord ] = None
462467
463468 if self .disable_reason :
464469 self .adapter .logger .debug ("Created disabled %s" , self )
@@ -504,8 +509,8 @@ def cache_record_uids(self, source: RecordData, nautobot_uid: Optional[Uid] = No
504509
505510 def first_pass (self , data : RecordData ) -> None :
506511 """Firts pass of data import."""
507- if self .pre_import :
508- if self .pre_import (data , ImporterPass .DEFINE_STRUCTURE ) != PreImportResult .USE_RECORD :
512+ if self .pre_import_record :
513+ if self .pre_import_record (data , ImporterPass .DEFINE_STRUCTURE ) != PreImportRecordResult .USE_RECORD :
509514 self .stats .first_pass_skipped += 1
510515 return
511516
@@ -522,14 +527,17 @@ def second_pass(self, data: RecordData) -> None:
522527 if self .disable_reason :
523528 return
524529
525- if self .pre_import :
526- if self .pre_import (data , ImporterPass .IMPORT_DATA ) != PreImportResult .USE_RECORD :
530+ if self .pre_import_record :
531+ if self .pre_import_record (data , ImporterPass .IMPORT_DATA ) != PreImportRecordResult .USE_RECORD :
527532 self .stats .second_pass_skipped += 1
528533 return
529534
530535 self .stats .second_pass_used += 1
531536
532- self .import_record (data )
537+ target = self .import_record (data )
538+
539+ if self .post_import_record :
540+ self .post_import_record (data , target )
533541
534542 def get_summary (self , content_type_id ) -> SourceModelSummary :
535543 """Get a summary of the model."""
@@ -544,7 +552,8 @@ def get_summary(self, content_type_id) -> SourceModelSummary:
544552 identifiers = self .identifiers ,
545553 disable_related_reference = self .disable_related_reference ,
546554 forward_references = self .forward_references and self .forward_references .__name__ or None ,
547- pre_import = self .pre_import and self .pre_import .__name__ or None ,
555+ pre_import = self .pre_import_record and self .pre_import_record .__name__ or None ,
556+ post_import = self .post_import_record and self .post_import_record .__name__ or None ,
548557 fields = sorted (fields , key = lambda field : field .name ),
549558 flags = str (self .flags ),
550559 default_reference_uid = serialize_to_summary (self .default_reference_uid ),
@@ -817,7 +826,7 @@ def set_default_reference(self, data: RecordData) -> None:
817826 """Set the default reference to this model."""
818827 self .default_reference_uid = self .cache_record (data )
819828
820- def post_import (self ) -> bool :
829+ def post_process_references (self ) -> bool :
821830 """Post import processing.
822831
823832 Assigns referenced content_types to referencing instances.
0 commit comments