
Commit 3ad9655

Merge pull request #5 from maliedvp/new_similarities
New similarities
2 parents ddf661f + 99e034c commit 3ad9655

File tree

src/neer_match/__init__.py
src/neer_match/matching_model.py
src/neer_match/similarity_map.py
test/__init__.py

4 files changed: +79 -5 lines changed
src/neer_match/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -3,4 +3,4 @@
 Neural-symbolic Entity Reasoning and Matching.
 """
 
-__version__ = '0.7.34'
+__version__ = '0.7.35'

src/neer_match/matching_model.py

Lines changed: 67 additions & 2 deletions
@@ -13,6 +13,8 @@
 import pandas as pd
 import tensorflow as tf
 import typing
+import pickle
+from pathlib import Path
 
 
 def _suggest(
@@ -222,6 +224,32 @@ def suggest(
         """
         return _suggest(self, left, right, count, batch_size=batch_size, **kwargs)
 
+    def save(self, target_directory: Path, name: str, include_optimizer: bool = True) -> None:
+        """Save the deep learning model to disk.
+
+        Saves the model architecture, weights, and optimizer state (optional),
+        along with the similarity map.
+
+        Args:
+            target_directory: The directory where the model will be saved.
+            name: The name of the model (used as a subdirectory).
+            include_optimizer: Whether to save the optimizer state.
+        """
+        # Ensure target_directory is a Path object
+        target_directory = Path(target_directory) / name / 'model'
+
+        # Ensure the directory exists
+        target_directory.mkdir(parents=True, exist_ok=True)
+
+        # Save the model architecture and weights
+        super().save(target_directory / "model.h5", include_optimizer=include_optimizer)
+
+        # Save the similarity map
+        with open(target_directory / "similarity_map.pkl", "wb") as f:
+            pickle.dump(self.similarity_map, f)
+
+        print(f"Model successfully saved to {target_directory}")
+
     @property
     def similarity_map(self) -> SimilarityMap:
         """Similarity Map of the Model."""
@@ -476,6 +504,7 @@ def fit(
         right: pd.DataFrame,
         matches: pd.DataFrame,
         epochs: int,
+        mismatch_share: float = 0.1,
         satisfiability_weight: float = 1.0,
         verbose: int = 1,
         log_mod_n: int = 1,
@@ -496,6 +525,7 @@ def fit(
             right: The right data frame.
             matches: The matches data frame.
             epochs: The number of epochs to train.
+            mismatch_share: The mismatch share.
             satisfiability_weight: The weight of the satisfiability loss.
             verbose: The verbosity level.
             log_mod_n: The log modulo.
@@ -512,7 +542,12 @@ def fit(
         # The remaining arguments are validated in the DataGenerator
 
         data_generator = DataGenerator(
-            self.record_pair_network.similarity_map, left, right, matches, **kwargs
+            self.record_pair_network.similarity_map,
+            left,
+            right,
+            matches,
+            mismatch_share=mismatch_share,
+            **kwargs
         )
 
         axioms = self._make_axioms(data_generator)
@@ -529,6 +564,7 @@ def evaluate(
         right: pd.DataFrame,
         matches: pd.DataFrame,
         batch_size: int = 16,
+        mismatch_share: float = 1.0,
         satisfiability_weight: float = 1.0,
     ) -> dict:
         """Evaluate the model.
@@ -542,14 +578,15 @@
             right: The right data frame.
             matches: The matches data frame.
             batch_size: Batch size.
+            mismatch_share: The mismatch share.
             satisfiability_weight: The weight of the satisfiability loss.
         """
         data_generator = DataGenerator(
             self.record_pair_network.similarity_map,
             left,
             right,
             matches,
-            mismatch_share=1.0,
+            mismatch_share=mismatch_share,
             batch_size=batch_size,
             shuffle=False,
         )
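
With these changes, the share of generated non-matching pairs is configurable in both training and evaluation rather than hard-coded; evaluate keeps its previous behavior through the default of 1.0. A hedged sketch of the new call sites; the model and data-frame variables are placeholders, only the mismatch_share keyword comes from this diff:

# Placeholder objects; only the mismatch_share keyword is taken from this diff.
model.fit(left, right, matches, epochs=10, mismatch_share=0.1)   # matches the new default in fit
metrics = model.evaluate(left, right, matches, batch_size=16,
                         mismatch_share=1.0)                     # same behavior as before this commit
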
@@ -634,6 +671,34 @@ def suggest(
         """
         return _suggest(self, left, right, count, batch_size=batch_size)
 
+    def save(self, target_directory: Path, name: str) -> None:
+        """Save the neural-symbolic model to disk.
+
+        Saves the record pair network, similarity map, and optimizer.
+
+        Args:
+            target_directory: The directory where the model will be saved.
+            name: The name of the model (used as a subdirectory).
+        """
+        # Ensure target_directory is a Path object
+        target_directory = Path(target_directory) / name / 'model'
+
+        # Ensure the directory exists
+        target_directory.mkdir(parents=True, exist_ok=True)
+
+        # Save the record pair network weights
+        self.record_pair_network.save_weights(target_directory / "record_pair_network.weights.h5")
+
+        # Save the similarity map
+        with open(target_directory / "similarity_map.pkl", "wb") as f:
+            pickle.dump(self.record_pair_network.similarity_map, f)
+
+        # Save the optimizer state
+        with open(target_directory / "optimizer.pkl", "wb") as f:
+            pickle.dump(self.optimizer.get_config(), f)
+
+        print(f"Model successfully saved to {target_directory}")
+
     @property
     def similarity_map(self) -> SimilarityMap:
         """Similarity Map of the Model."""

src/neer_match/similarity_map.py

Lines changed: 10 additions & 1 deletion
@@ -5,7 +5,7 @@
 records of two datasets.
 """
 
-from rapidfuzz import distance
+from rapidfuzz import distance, fuzz
 import numpy
 import typing
 
@@ -30,6 +30,7 @@ def gaussian(x: typing.Union[float, int], y: typing.Union[float, int]) -> float:
 def available_similarities() -> typing.Dict[str, typing.Callable]:
     """Return the list of available similarities."""
     return {
+        "basic_ratio": fuzz.ratio,
         "damerau_levenshtein": distance.DamerauLevenshtein.normalized_similarity,
         "discrete": discrete,
         "euclidean": euclidean,
@@ -41,8 +42,16 @@ def available_similarities() -> typing.Dict[str, typing.Callable]:
         "lcsseq": distance.LCSseq.normalized_similarity,
         "levenshtein": distance.Levenshtein.normalized_similarity,
         "osa": distance.OSA.normalized_similarity,
+        "partial_ratio": fuzz.partial_ratio,
+        "partial_ratio_alignment": fuzz.partial_ratio_alignment,
+        "partial_token_ratio": fuzz.partial_token_ratio,
+        "partial_token_set_ratio": fuzz.partial_token_set_ratio,
+        "partial_token_sort_ratio": fuzz.partial_token_sort_ratio,
         "postfix": distance.Postfix.normalized_similarity,
         "prefix": distance.Prefix.normalized_similarity,
+        "token_ratio": fuzz.token_ratio,
+        "token_set_ratio": fuzz.token_set_ratio,
+        "token_sort_ratio": fuzz.token_sort_ratio,
     }
 
 
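The new entries expose RapidFuzz's fuzz scorers alongside the existing normalized edit distances. One practical difference worth noting: fuzz.ratio and its token variants return scores on a 0-100 scale, while the distance.*.normalized_similarity functions return values in [0, 1], and partial_ratio_alignment returns an alignment object rather than a plain score. A small illustration of why the token-sorted scorers help with reordered fields; the strings are illustrative only and not taken from the package's tests:

# Quick comparison of order-sensitive vs. token-based scorers (illustrative values).
from rapidfuzz import distance, fuzz

left, right = "Nintendo EAD", "EAD Nintendo"

print(distance.Levenshtein.normalized_similarity(left, right))  # order-sensitive, value in [0, 1]
print(fuzz.token_sort_ratio(left, right))                       # 100.0 after sorting tokens, 0-100 scale
print(fuzz.token_set_ratio("Rockstar North", "Rockstar North Ltd."))  # high despite the extra token
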
test/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
     "title": ["jaro_winkler"],
     "platform": ["levenshtein", "jaro"],
     "year": ["euclidean", "discrete"],
-    "developer~dev": ["jaro"],
+    "developer~dev": ["jaro", "token_sort_ratio"],
 }
 
 items = [
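
For context, the dictionary above is the similarity-map instructions used by the test suite: each key names a column (or a left~right column pair such as developer~dev), and each value lists the similarity functions to apply to that pair. A hedged sketch of how such instructions might be turned into a SimilarityMap; the constructor call is an assumption about the package's API, not something shown in this diff:

# Assumed construction; SimilarityMap(instructions) is inferred from the
# package's API and is not part of this diff.
from neer_match.similarity_map import SimilarityMap

instructions = {
    "title": ["jaro_winkler"],
    "platform": ["levenshtein", "jaro"],
    "year": ["euclidean", "discrete"],
    "developer~dev": ["jaro", "token_sort_ratio"],  # "developer" in the left data, "dev" in the right
}
smap = SimilarityMap(instructions)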
