Skip to content

Commit 851bede

Browse files
authored
Merge pull request #18 from armingh2000/fix/load-score
Fix/load score
2 parents 67a0f28 + 81d9c60 commit 851bede

File tree

6 files changed

+95
-40
lines changed

6 files changed

+95
-40
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ All notable changes to this project will be documented in this file.
6161

6262
- Renamed test_scorer to test_fact_scorer.
6363

64+
## v1.0.1 - 2024-04-14
65+
66+
- Fix score calculation when loading from dumped data.
67+
- Add tests for the fix.
68+
- Remove unnecessary code.
69+
6470
<!--
6571
### Added
6672

FactScoreLite/atomic_facts.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -186,19 +186,3 @@ def fix_sentence_splitter(self, sentences: list, initials: list) -> list:
186186
results.append(sent)
187187

188188
return results
189-
190-
191-
if __name__ == "__main__":
192-
generator = AtomicFactGenerator()
193-
text = """
194-
To winterize your battery and prevent damage:
195-
196-
1. **For the Li-ion battery**:
197-
- Avoid storing the vehicle in temperatures below -13°F (-25°C) for more than seven days to prevent the Li-ion battery from freezing.
198-
- Move the vehicle to a warm location if the outside temperature is -13°F (-25°C) or below, as it may freeze and be unable to charge or power the vehicle.
199-
200-
2. **For the 12-volt battery**:
201-
- Ensure it is fully charged during extremely cold weather conditions to prevent the battery fluid from freezing and possibly causing damage to the battery.
202-
""".strip()
203-
204-
print(generator.run(text))

FactScoreLite/fact_scorer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def get_score(self, facts: list, knowledge_source: str) -> list:
3434

3535
prompt += "\n\n"
3636

37-
prompt += f"Input:\n{atom} True or False?\nOutput:\n"
37+
prompt += f"Input: {atom} True or False?\nOutput:\n"
3838

3939
output = self.openai_agent.generate(prompt)
4040

FactScoreLite/factscore.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from . import FactScorer, AtomicFactGenerator
33
from .state_handler import StateHandler
44
from . import configs
5+
from tqdm import tqdm
56

67

78
class FactScore:
@@ -25,9 +26,11 @@ def get_facts(self, generations: list) -> list:
2526
list: A list of generation-facts pairs dictionaries.
2627
"""
2728

29+
print("Extracting facts from generations...")
30+
2831
generation_facts_pairs = self.facts_handler.load()
2932

30-
for generation in generations[len(generation_facts_pairs) :]:
33+
for generation in tqdm(generations[len(generation_facts_pairs) :]):
3134
atomic_facts_of_generation = self.atomic_fact_generator.run(generation)
3235
atomic_facts_of_generation = [
3336
fact
@@ -48,6 +51,30 @@ def get_facts(self, generations: list) -> list:
4851

4952
return generation_facts_pairs
5053

54+
def calculate_score(self, decision: list) -> tuple:
55+
"""
56+
Calculates the score of a generation based on whether its facts are supported by the knowledge source.
57+
58+
Args:
59+
decision (list): A list containing dictionaries of {output, is_supported, fact} for each fact of a generation.
60+
61+
Returns:
62+
tuple: A tuple containing the score and the original score (without applying gamma penalty).
63+
"""
64+
65+
score = np.mean([d["is_supported"] for d in decision])
66+
init_score = score
67+
68+
if self.gamma:
69+
penalty = (
70+
1.0
71+
if len(decision) >= self.gamma
72+
else np.exp(1 - self.gamma / len(decision))
73+
)
74+
score = penalty * score
75+
76+
return score, init_score
77+
5178
def get_decisions(
5279
self, generation_facts_pairs: list, knowledge_sources: list
5380
) -> list:
@@ -66,31 +93,27 @@ def get_decisions(
6693
and initial scores (original score without applying gamma penalty).
6794
"""
6895

96+
print("Generating decisions...")
97+
6998
decisions = self.decisions_handler.load()
7099
scores = []
71100
init_scores = []
72101

73-
for entry in generation_facts_pairs[len(decisions) :]:
74-
generation, facts = entry["generation"], entry["facts"]
75-
score = None
76-
77-
if facts:
78-
decision = self.fact_scorer.get_score(facts, knowledge_sources)
79-
score = np.mean([d["is_supported"] for d in decision])
102+
for entry in decisions:
103+
score, init_score = self.calculate_score(entry["decision"])
104+
init_scores.append(init_score)
105+
scores.append(score)
80106

81-
if self.gamma:
82-
init_scores.append(score)
83-
penalty = (
84-
1.0
85-
if len(facts) > self.gamma
86-
else np.exp(1 - self.gamma / len(facts))
87-
)
88-
score = penalty * score
107+
for entry in tqdm(generation_facts_pairs[len(decisions) :]):
108+
generation, facts = entry["generation"], entry["facts"]
89109

90-
decisions.append({"generation": generation, "decision": decision})
91-
self.decisions_handler.save(decisions)
110+
decision = self.fact_scorer.get_score(facts, knowledge_sources)
111+
score, init_score = self.calculate_score(decision)
92112

113+
init_scores.append(init_score)
93114
scores.append(score)
115+
decisions.append({"generation": generation, "decision": decision})
116+
self.decisions_handler.save(decisions)
94117

95118
assert len(facts) == len(
96119
decision
@@ -100,7 +123,7 @@ def get_decisions(
100123
generation_facts_pairs
101124
), "Number of decisions and generation-facts pairs should be the same."
102125

103-
return scores, decisions, init_scores
126+
return scores, init_scores
104127

105128
def get_factscore(
106129
self,
@@ -124,6 +147,6 @@ def get_factscore(
124147
), "`generations` and `knowledge_sources` should have the same length."
125148

126149
facts = self.get_facts(generations)
127-
scores, decisions, init_scores = self.get_decisions(facts, knowledge_sources)
150+
scores, init_scores = self.get_decisions(facts, knowledge_sources)
128151

129152
return np.mean(scores), np.mean(init_scores)

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[metadata]
22
name = FactScoreLite
3-
version = 1.0.0
3+
version = 1.0.1
44
author = armingh2000
55
author_email =
66
license = MIT

tests/test_factscore.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def test_get_decisions_with_valid_input(
6565
{"is_supported": True},
6666
{"is_supported": False},
6767
]
68-
scores, decisions, init_scores = fact_score.get_decisions(
68+
scores, init_scores = fact_score.get_decisions(
6969
generation_facts_pairs, knowledge_sources
7070
)
7171

@@ -85,7 +85,12 @@ def test_get_factscore_from_saved_states(
8585
):
8686
mock_state_handler.load.side_effect = [
8787
[{"generation": "gen1", "facts": ["fact1", "fact2"]}],
88-
[{"generation": "gen1", "facts": [{"fact": "fact1", "is_supported": True}]}],
88+
[
89+
{
90+
"generation": "gen1",
91+
"decision": [{"fact": "fact1", "is_supported": True, "output": "True"}],
92+
}
93+
],
8994
] # First for facts, second for decisions
9095
generations = ["generation1", "generation2"]
9196
knowledge_sources = ["source1", "source2"]
@@ -97,3 +102,40 @@ def test_get_factscore_from_saved_states(
97102
avg_score, avg_init_score = fact_score.get_factscore(generations, knowledge_sources)
98103
assert isinstance(avg_score, float)
99104
assert isinstance(avg_init_score, float)
105+
106+
107+
@pytest.mark.parametrize(
108+
"decision, expected_score",
109+
[
110+
([{"is_supported": True} for _ in range(10)], 1.0),
111+
([{"is_supported": True} for _ in range(5)], 1.0),
112+
([{"is_supported": False} for _ in range(10)], 0.0),
113+
(
114+
[
115+
{"is_supported": True} if i % 2 == 0 else {"is_supported": False}
116+
for i in range(10)
117+
],
118+
0.5,
119+
),
120+
],
121+
)
122+
def test_calculate_score_various_decisions(fact_score, decision, expected_score):
123+
score, init_score = fact_score.calculate_score(decision)
124+
assert (
125+
init_score == expected_score
126+
), "Initial score should match expected mean of decisions"
127+
if len(decision) >= fact_score.gamma:
128+
assert (
129+
score == expected_score
130+
), "Score should not be penalized when decision count exceeds gamma"
131+
else:
132+
assert (
133+
score != expected_score
134+
), "Score should be penalized when decision count is below gamma"
135+
136+
137+
def test_gamma_zero(fact_score):
138+
fact_score.gamma = 0 # Setting gamma to zero
139+
decision = [{"is_supported": True} for _ in range(5)]
140+
score, init_score = fact_score.calculate_score(decision)
141+
assert score == init_score, "No penalty should apply when gamma is zero"

0 commit comments

Comments
 (0)