13
13
import pandas as pd
14
14
import tensorflow as tf
15
15
import typing
16
+ import pickle
17
+ from pathlib import Path
16
18
17
19
18
20
def _suggest (
@@ -222,6 +224,32 @@ def suggest(
222
224
"""
223
225
return _suggest (self , left , right , count , batch_size = batch_size , ** kwargs )
224
226
227
+ def save (self , target_directory : Path , name : str , include_optimizer : bool = True ) -> None :
228
+ """Save the deep learning model to disk.
229
+
230
+ Saves the model architecture, weights, and optimizer state (optional),
231
+ along with the similarity map.
232
+
233
+ Args:
234
+ target_directory: The directory where the model will be saved.
235
+ name: The name of the model (used as a subdirectory).
236
+ include_optimizer: Whether to save the optimizer state.
237
+ """
238
+ # Ensure target_directory is a Path object
239
+ target_directory = Path (target_directory ) / name / 'model'
240
+
241
+ # Ensure the directory exists
242
+ target_directory .mkdir (parents = True , exist_ok = True )
243
+
244
+ # Save the model architecture and weights
245
+ super ().save (target_directory / "model.h5" , include_optimizer = include_optimizer )
246
+
247
+ # Save the similarity map
248
+ with open (target_directory / "similarity_map.pkl" , "wb" ) as f :
249
+ pickle .dump (self .similarity_map , f )
250
+
251
+ print (f"Model successfully saved to { target_directory } " )
252
+
225
253
@property
226
254
def similarity_map (self ) -> SimilarityMap :
227
255
"""Similarity Map of the Model."""
@@ -476,6 +504,7 @@ def fit(
476
504
right : pd .DataFrame ,
477
505
matches : pd .DataFrame ,
478
506
epochs : int ,
507
+ mismatch_share : float = 0.1 ,
479
508
satisfiability_weight : float = 1.0 ,
480
509
verbose : int = 1 ,
481
510
log_mod_n : int = 1 ,
@@ -496,6 +525,7 @@ def fit(
496
525
right: The right data frame.
497
526
matches: The matches data frame.
498
527
epochs: The number of epochs to train.
528
+ mismatch_share: The mismatch share.
499
529
satisfiability_weight: The weight of the satisfiability loss.
500
530
verbose: The verbosity level.
501
531
log_mod_n: The log modulo.
@@ -512,7 +542,12 @@ def fit(
512
542
# The remaining arguments are validated in the DataGenerator
513
543
514
544
data_generator = DataGenerator (
515
- self .record_pair_network .similarity_map , left , right , matches , ** kwargs
545
+ self .record_pair_network .similarity_map ,
546
+ left ,
547
+ right ,
548
+ matches ,
549
+ mismatch_share = mismatch_share ,
550
+ ** kwargs
516
551
)
517
552
518
553
axioms = self ._make_axioms (data_generator )
@@ -529,6 +564,7 @@ def evaluate(
529
564
right : pd .DataFrame ,
530
565
matches : pd .DataFrame ,
531
566
batch_size : int = 16 ,
567
+ mismatch_share : float = 1.0 ,
532
568
satisfiability_weight : float = 1.0 ,
533
569
) -> dict :
534
570
"""Evaluate the model.
@@ -542,14 +578,15 @@ def evaluate(
542
578
right: The right data frame.
543
579
matches: The matches data frame.
544
580
batch_size: Batch size.
581
+ mismatch_share: The mismatch share.
545
582
satisfiability_weight: The weight of the satisfiability loss.
546
583
"""
547
584
data_generator = DataGenerator (
548
585
self .record_pair_network .similarity_map ,
549
586
left ,
550
587
right ,
551
588
matches ,
552
- mismatch_share = 1.0 ,
589
+ mismatch_share = mismatch_share ,
553
590
batch_size = batch_size ,
554
591
shuffle = False ,
555
592
)
@@ -634,6 +671,34 @@ def suggest(
634
671
"""
635
672
return _suggest (self , left , right , count , batch_size = batch_size )
636
673
674
+ def save (self , target_directory : Path , name : str ) -> None :
675
+ """Save the neural-symbolic model to disk.
676
+
677
+ Saves the record pair network, similarity map, and optimizer.
678
+
679
+ Args:
680
+ target_directory: The directory where the model will be saved.
681
+ name: The name of the model (used as a subdirectory).
682
+ """
683
+ # Ensure target_directory is a Path object
684
+ target_directory = Path (target_directory ) / name / 'model'
685
+
686
+ # Ensure the directory exists
687
+ target_directory .mkdir (parents = True , exist_ok = True )
688
+
689
+ # Save the record pair network weights
690
+ self .record_pair_network .save_weights (target_directory / "record_pair_network.weights.h5" )
691
+
692
+ # Save the similarity map
693
+ with open (target_directory / "similarity_map.pkl" , "wb" ) as f :
694
+ pickle .dump (self .record_pair_network .similarity_map , f )
695
+
696
+ # Save the optimizer state
697
+ with open (target_directory / "optimizer.pkl" , "wb" ) as f :
698
+ pickle .dump (self .optimizer .get_config (), f )
699
+
700
+ print (f"Model successfully saved to { target_directory } " )
701
+
637
702
@property
638
703
def similarity_map (self ) -> SimilarityMap :
639
704
"""Similarity Map of the Model."""
0 commit comments