From 2c56425bd702977d197475522967f1b2264d141c Mon Sep 17 00:00:00 2001 From: Zuhaitz Beloki Leitza Date: Fri, 20 Sep 2024 13:49:00 +0200 Subject: [PATCH 1/2] Macro averaging mode added for computing the F0.5 score. --- errant/commands/compare_m2.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/errant/commands/compare_m2.py b/errant/commands/compare_m2.py index 2e477d0..d627976 100644 --- a/errant/commands/compare_m2.py +++ b/errant/commands/compare_m2.py @@ -54,6 +54,12 @@ def parse_args(): help="Value of beta in F-score. (default: 0.5)", default=0.5, type=float) + parser.add_argument( + "--f_average", + help="Compute the F score using 'micro' or 'macro' averaging", + default="micro", + choices=["micro", "macro"] + ) parser.add_argument( "-v", "--verbose", @@ -322,6 +328,13 @@ def computeFScore(tp, fp, fn, beta): f = float((1+(beta**2))*p*r)/(((beta**2)*p)+r) if p+r else 0.0 return round(p, 4), round(r, 4), round(f, 4) +def computeMacroFScore(best_cats, beta): + class_scores = [ computeFScore(tp, fp, fn, beta) for tp, fp, fn in best_cats.values() ] + p = sum([ p for p, _, _ in class_scores ]) / len(class_scores) + r = sum([ r for _, r, _ in class_scores ]) / len(class_scores) + f = sum([ f for _, _, f in class_scores ]) / len(class_scores) + return round(p, 4), round(r, 4), round(f, 4) + # Input 1-2: Two error category dicts. Key is cat, value is list of TP, FP, FN. # Output: The dictionaries combined with cumulative TP, FP, FN. def merge_dict(dict1, dict2): @@ -382,13 +395,20 @@ def print_results(best, best_cats, args): print(cat.ljust(14), str(cnts[0]).ljust(8), str(cnts[1]).ljust(8), str(cnts[2]).ljust(8), str(cat_p).ljust(8), str(cat_r).ljust(8), cat_f) + # Compute F-score + if args.f_average == 'micro': + f_score = computeFScore(best["tp"], best["fp"], best["fn"], args.beta) + else: # args.f_average == 'macro' + f_score = computeMacroFScore(best_cats, args.beta) + # Print the overall results. print("") - print('{:=^46}'.format(title)) - print("\t".join(["TP", "FP", "FN", "Prec", "Rec", "F"+str(args.beta)])) + print('{:=^{width}}'.format(title, width=(46 if args.f_average=='micro' else 52))) + print("\t".join(["TP", "FP", "FN", "Prec", "Rec", + "F"+str(args.beta)+(" (macro)" if args.f_average=='macro' else "")])) print("\t".join(map(str, [best["tp"], best["fp"], - best["fn"]]+list(computeFScore(best["tp"], best["fp"], best["fn"], args.beta))))) - print('{:=^46}'.format("")) + best["fn"]]+list(f_score)))) + print('{:=^{width}}'.format("", width=(46 if args.f_average=='micro' else 52))) print("") def print_table(table): From 943b5e703302ca6abe0425596b7142c5b1bb2cc6 Mon Sep 17 00:00:00 2001 From: Zuhaitz Beloki Leitza Date: Fri, 20 Sep 2024 15:43:28 +0200 Subject: [PATCH 2/2] Macro-averaging is also considered when finding the best hyp-ref combination --- errant/commands/compare_m2.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/errant/commands/compare_m2.py b/errant/commands/compare_m2.py index d627976..7440732 100644 --- a/errant/commands/compare_m2.py +++ b/errant/commands/compare_m2.py @@ -26,7 +26,7 @@ def main(): original_sentence = sent[0][2:].split("\nA")[0] # Evaluate edits and get best TP, FP, FN hyp+ref combo. count_dict, cat_dict = evaluate_edits( - hyp_dict, ref_dict, best_dict, sent_id, original_sentence, args) + hyp_dict, ref_dict, best_dict, best_cats, sent_id, original_sentence, args) # Merge these dicts with best_dict and best_cats best_dict += Counter(count_dict) best_cats = merge_dict(best_cats, cat_dict) @@ -202,11 +202,12 @@ def process_edits(edits, args): # Input 1: A hyp dict; key is coder_id, value is dict of processed hyp edits. # Input 2: A ref dict; key is coder_id, value is dict of processed ref edits. # Input 3: A dictionary of the best corpus level TP, FP and FN counts so far. -# Input 4: Sentence ID (for verbose output only) -# Input 5: Command line args +# Input 4: A dictionary of the best corpus level TP, FP and FN counts grouped by categories. +# Input 5: Sentence ID (for verbose output only) +# Input 6: Command line args # Output 1: A dict of the best corpus level TP, FP and FN for the input sentence. # Output 2: The corresponding error type dict for the above dict. -def evaluate_edits(hyp_dict, ref_dict, best, sent_id, original_sentence, args): +def evaluate_edits(hyp_dict, ref_dict, best, best_cats, sent_id, original_sentence, args): # Verbose output: display the original sentence if args.verbose: print('{:-^40}'.format("")) @@ -223,8 +224,20 @@ def evaluate_edits(hyp_dict, ref_dict, best, sent_id, original_sentence, args): # Compute the local sentence scores (for verbose output only) loc_p, loc_r, loc_f = computeFScore(tp, fp, fn, args.beta) # Compute the global sentence scores - p, r, f = computeFScore( - tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta) + if args.f_average == 'micro': + p, r, f = computeFScore( + tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta) + else: # args.f_average == 'macro' + # combine best_cats and current cat_dict to get the global cat_dict + tmp_cats = {} + for cat in (best_cats.keys() | cat_dict.keys()): + if cat not in best_cats: + tmp_cats[cat] = cat_dict[cat] + elif cat not in cat_dict: + tmp_cats[cat] = best_cats[cat] + else: + tmp_cats[cat] = [ x1+x2 for x1, x2 in zip(best_cats[cat], cat_dict[cat]) ] + p, r, f = computeMacroFScore(tmp_cats, args.beta) # Save the scores if they are better in terms of: # 1. Higher F-score # 2. Same F-score, higher TP @@ -328,8 +341,8 @@ def computeFScore(tp, fp, fn, beta): f = float((1+(beta**2))*p*r)/(((beta**2)*p)+r) if p+r else 0.0 return round(p, 4), round(r, 4), round(f, 4) -def computeMacroFScore(best_cats, beta): - class_scores = [ computeFScore(tp, fp, fn, beta) for tp, fp, fn in best_cats.values() ] +def computeMacroFScore(cats, beta): + class_scores = [ computeFScore(tp, fp, fn, beta) for tp, fp, fn in cats.values() ] p = sum([ p for p, _, _ in class_scores ]) / len(class_scores) r = sum([ r for _, r, _ in class_scores ]) / len(class_scores) f = sum([ f for _, _, f in class_scores ]) / len(class_scores)