53 changes: 43 additions & 10 deletions errant/commands/compare_m2.py
@@ -26,7 +26,7 @@ def main():
original_sentence = sent[0][2:].split("\nA")[0]
# Evaluate edits and get best TP, FP, FN hyp+ref combo.
count_dict, cat_dict = evaluate_edits(
hyp_dict, ref_dict, best_dict, sent_id, original_sentence, args)
hyp_dict, ref_dict, best_dict, best_cats, sent_id, original_sentence, args)
# Merge these dicts with best_dict and best_cats
best_dict += Counter(count_dict)
best_cats = merge_dict(best_cats, cat_dict)
@@ -54,6 +54,12 @@ def parse_args():
help="Value of beta in F-score. (default: 0.5)",
default=0.5,
type=float)
parser.add_argument(
"--f_average",
help="Compute the F score using 'micro' or 'macro' averaging",
default="micro",
choices=["micro", "macro"]
)
parser.add_argument(
"-v",
"--verbose",
Expand Down Expand Up @@ -196,11 +202,12 @@ def process_edits(edits, args):
# Input 1: A hyp dict; key is coder_id, value is dict of processed hyp edits.
# Input 2: A ref dict; key is coder_id, value is dict of processed ref edits.
# Input 3: A dictionary of the best corpus level TP, FP and FN counts so far.
# Input 4: Sentence ID (for verbose output only)
# Input 5: Command line args
# Input 4: A dict of the best corpus level TP, FP and FN counts grouped by error category.
# Input 5: Sentence ID (for verbose output only)
# Input 6: The original sentence (for verbose output only)
# Input 7: Command line args
# Output 1: A dict of the best corpus level TP, FP and FN for the input sentence.
# Output 2: The corresponding error type dict for the above dict.
def evaluate_edits(hyp_dict, ref_dict, best, sent_id, original_sentence, args):
def evaluate_edits(hyp_dict, ref_dict, best, best_cats, sent_id, original_sentence, args):
# Verbose output: display the original sentence
if args.verbose:
print('{:-^40}'.format(""))
@@ -217,8 +224,20 @@ def evaluate_edits(hyp_dict, ref_dict, best, sent_id, original_sentence, args):
# Compute the local sentence scores (for verbose output only)
loc_p, loc_r, loc_f = computeFScore(tp, fp, fn, args.beta)
# Compute the global sentence scores
p, r, f = computeFScore(
tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta)
if args.f_average == 'micro':
p, r, f = computeFScore(
tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta)
else: # args.f_average == 'macro'
# Combine best_cats and the current cat_dict to get the global cat_dict
tmp_cats = {}
for cat in best_cats.keys() | cat_dict.keys():
if cat not in best_cats:
tmp_cats[cat] = cat_dict[cat]
elif cat not in cat_dict:
tmp_cats[cat] = best_cats[cat]
else:
tmp_cats[cat] = [x1 + x2 for x1, x2 in zip(best_cats[cat], cat_dict[cat])]
p, r, f = computeMacroFScore(tmp_cats, args.beta)
# Save the scores if they are better in terms of:
# 1. Higher F-score
# 2. Same F-score, higher TP
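
The category merge above unions the two key sets and sums the [TP, FP, FN] triples element-wise. A standalone sketch of the same logic with invented counts, equivalently expressed with dict.get defaults:

    # Hypothetical running totals vs. per-sentence counts; values are [TP, FP, FN].
    best_cats = {"DET": [3, 1, 2]}
    cat_dict = {"DET": [1, 0, 1], "PREP": [2, 2, 0]}
    tmp_cats = {}
    for cat in best_cats.keys() | cat_dict.keys():
        a = best_cats.get(cat, [0, 0, 0])
        b = cat_dict.get(cat, [0, 0, 0])
        tmp_cats[cat] = [x1 + x2 for x1, x2 in zip(a, b)]
    print(tmp_cats)  # {'DET': [4, 1, 3], 'PREP': [2, 2, 0]} (key order may vary)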
@@ -322,6 +341,13 @@ def computeFScore(tp, fp, fn, beta):
f = float((1+(beta**2))*p*r)/(((beta**2)*p)+r) if p+r else 0.0
return round(p, 4), round(r, 4), round(f, 4)

# Input 1: An error category dict; key is cat, value is list of TP, FP, FN.
# Input 2: Value of beta in F-score.
# Output: Macro-averaged P, R and F-score over all categories.
def computeMacroFScore(cats, beta):
if not cats: # Avoid division by zero when no categories have been seen.
return 0.0, 0.0, 0.0
class_scores = [computeFScore(tp, fp, fn, beta) for tp, fp, fn in cats.values()]
p = sum(p for p, _, _ in class_scores) / len(class_scores)
r = sum(r for _, r, _ in class_scores) / len(class_scores)
f = sum(f for _, _, f in class_scores) / len(class_scores)
return round(p, 4), round(r, 4), round(f, 4)

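Micro averaging pools TP, FP and FN over all error types before computing a single F-score, so frequent types dominate the result; macro averaging computes one F-score per type and averages them with equal weight. A quick numeric contrast using the two functions above, with invented counts:

    # Hypothetical per-category counts; values are [TP, FP, FN].
    cats = {"DET": [8, 2, 2], "PREP": [1, 3, 4]}
    tp, fp, fn = (sum(v[i] for v in cats.values()) for i in range(3))  # 9, 5, 6
    print(computeFScore(tp, fp, fn, 0.5))  # micro: (0.6429, 0.6, 0.6338)
    print(computeMacroFScore(cats, 0.5))   # macro: approx (0.525, 0.5, 0.519)
    # The rare, poorly handled PREP type drags the macro score well below the micro score.
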
# Input 1-2: Two error category dicts. Key is cat, value is list of TP, FP, FN.
# Output: The dictionaries combined with cumulative TP, FP, FN.
def merge_dict(dict1, dict2):
@@ -382,13 +408,20 @@ def print_results(best, best_cats, args):
print(cat.ljust(14), str(cnts[0]).ljust(8), str(cnts[1]).ljust(8),
str(cnts[2]).ljust(8), str(cat_p).ljust(8), str(cat_r).ljust(8), cat_f)

# Compute F-score
if args.f_average == 'micro':
f_score = computeFScore(best["tp"], best["fp"], best["fn"], args.beta)
else: # args.f_average == 'macro'
f_score = computeMacroFScore(best_cats, args.beta)

# Print the overall results.
print("")
print('{:=^46}'.format(title))
print("\t".join(["TP", "FP", "FN", "Prec", "Rec", "F"+str(args.beta)]))
print('{:=^{width}}'.format(title, width=(46 if args.f_average=='micro' else 52)))
print("\t".join(["TP", "FP", "FN", "Prec", "Rec",
"F"+str(args.beta)+(" (macro)" if args.f_average=='macro' else "")]))
print("\t".join(map(str, [best["tp"], best["fp"],
best["fn"]]+list(computeFScore(best["tp"], best["fp"], best["fn"], args.beta)))))
print('{:=^46}'.format(""))
best["fn"]]+list(f_score))))
print('{:=^{width}}'.format("", width=(46 if args.f_average=='micro' else 52)))
print("")

def print_table(table):
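
For illustration, with --f_average macro the overall results block printed by print_results would look roughly like this (title, counts and scores invented; the wider 52-character rule accommodates the longer " (macro)" column header):

    ============== Span-Based Correction ===============
    TP	FP	FN	Prec	Rec	F0.5 (macro)
    9	5	6	0.525	0.5	0.519
    ====================================================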