-
Notifications
You must be signed in to change notification settings - Fork 446
AUC chart #2171
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
AUC chart #2171
Changes from 15 commits
6fb178f
ff63c42
62be4c6
d30c5d4
a280519
861bde5
49537d1
002fe4d
84ef9cf
925d532
f283823
6b4fc7d
6baeba2
03e15b4
7b0d2cc
a05ebc3
41e42cd
17f3682
f197aec
2a3ca4c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| // Copyright (c) Microsoft Corporation. | ||
| // Licensed under the MIT License. | ||
|
|
||
| export interface IAUCData { | ||
| AUCData: number[][]; | ||
| selectedLabels: string[]; | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,148 @@ | ||
| // Copyright (c) Microsoft Corporation. | ||
| // Licensed under the MIT License. | ||
|
|
||
| import { binarizeData, calculatePerClassROCData } from "./calculateAUCData"; | ||
|
|
||
| describe("Test binarizeData", () => { | ||
| it("should binarize numbers", () => { | ||
| const result = binarizeData([1, 3, 4, 0], [0, 1, 2, 3, 4]); | ||
| expect(result).toEqual([ | ||
| [0, 1, 0, 0, 0], | ||
| [0, 0, 0, 1, 0], | ||
| [0, 0, 0, 0, 1], | ||
| [1, 0, 0, 0, 0] | ||
| ]); | ||
| }); | ||
| it("should binarize strings", () => { | ||
| const result = binarizeData( | ||
| ["one", "two", "three"], | ||
| ["three", "one", "two"] | ||
| ); | ||
| expect(result).toEqual([ | ||
| [0, 1, 0], | ||
| [0, 0, 1], | ||
| [1, 0, 0] | ||
| ]); | ||
| }); | ||
| it("should binarize binary data", () => { | ||
| const result = binarizeData([1, 0, 1, 0], [0, 1]); | ||
| expect(result).toEqual([ | ||
| [0, 1], | ||
| [1, 0], | ||
| [0, 1], | ||
| [1, 0] | ||
| ]); | ||
| }); | ||
| }); | ||
| describe("Test calculatePerClassROCData", () => { | ||
| it("generate x,y data corresponding to fpr and tpr respectively", () => { | ||
| const result = calculatePerClassROCData( | ||
| [0.33, 0.32, 0.34, 0.29, 0.12, 0.41, 0.4, 0.39], | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you add zeros after the comma to get all the points closer together (0.032, 0.033, 0.034, for example), then it will not do it right, will it? |
||
| [0, 1, 1, 0, 1, 0, 0, 0] | ||
| ); | ||
| expect(result).toEqual({ | ||
| points: [ | ||
| { x: 1, y: 1 }, | ||
| { x: 1, y: 1 }, | ||
| { x: 1, y: 1 }, | ||
| { x: 1, y: 1 }, | ||
| { x: 1, y: 1 }, | ||
| { x: 1, y: 1 }, | ||
| { x: 1, y: 1 }, | ||
| { x: 1, y: 1 }, | ||
| { x: 1, y: 1 }, | ||
| { x: 1, y: 1 }, | ||
| { x: 1, y: 1 }, | ||
| { x: 1, y: 1 }, | ||
| { x: 1, y: 1 }, | ||
| { x: 1, y: 0.6666666666666666 }, | ||
| { x: 1, y: 0.6666666666666666 }, | ||
| { x: 1, y: 0.6666666666666666 }, | ||
| { x: 1, y: 0.6666666666666666 }, | ||
| { x: 1, y: 0.6666666666666666 }, | ||
| { x: 1, y: 0.6666666666666666 }, | ||
| { x: 1, y: 0.6666666666666666 }, | ||
| { x: 1, y: 0.6666666666666666 }, | ||
| { x: 1, y: 0.6666666666666666 }, | ||
| { x: 1, y: 0.6666666666666666 }, | ||
| { x: 1, y: 0.6666666666666666 }, | ||
| { x: 1, y: 0.6666666666666666 }, | ||
| { x: 1, y: 0.6666666666666666 }, | ||
| { x: 1, y: 0.6666666666666666 }, | ||
| { x: 1, y: 0.6666666666666666 }, | ||
| { x: 1, y: 0.6666666666666666 }, | ||
| { x: 0.8, y: 0.6666666666666666 }, | ||
| { x: 0.8, y: 0.6666666666666666 }, | ||
| { x: 0.8, y: 0.6666666666666666 }, | ||
| { x: 0.8, y: 0.3333333333333333 }, | ||
| { x: 0.6, y: 0.3333333333333333 }, | ||
| { x: 0.6, y: 0 }, | ||
| { x: 0.6, y: 0 }, | ||
| { x: 0.6, y: 0 }, | ||
| { x: 0.6, y: 0 }, | ||
| { x: 0.6, y: 0 }, | ||
| { x: 0.4, y: 0 }, | ||
| { x: 0.2, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 0 } | ||
| ] | ||
| }); | ||
| }); | ||
| }); | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,142 @@ | ||
| // Copyright (c) Microsoft Corporation. | ||
| // Licensed under the MIT License. | ||
|
|
||
| import { localization } from "@responsible-ai/localization"; | ||
| import { SeriesOptionsType } from "highcharts"; | ||
| import { range, unzip } from "lodash"; | ||
|
|
||
| import { IDataset } from "../Interfaces/IDataset"; | ||
|
|
||
| interface IPoint { | ||
| x: number; | ||
| y: number; | ||
| } | ||
| export interface IROCData { | ||
| points: IPoint[]; | ||
| } | ||
|
|
||
| function getStaticROCData(): SeriesOptionsType[] { | ||
| return [ | ||
| { | ||
| data: [ | ||
| { x: 0, y: 0 }, | ||
| { x: 0, y: 1 }, | ||
| { x: 1, y: 1 } | ||
| ], | ||
| name: localization.Interpret.Charts.Ideal, | ||
| type: "line" | ||
| }, | ||
| { | ||
| data: [ | ||
| { x: 0, y: 0 }, | ||
| { x: 1, y: 1 } | ||
| ], | ||
| name: localization.Interpret.Charts.Random, | ||
| type: "line" | ||
| } | ||
| ]; | ||
| } | ||
|
|
||
| export function calculatePerClassROCData( | ||
| probabilityY: number[], | ||
| binY: number[] | ||
| ): IROCData { | ||
| const rocData: IROCData = { | ||
| points: [] | ||
| }; | ||
| const thresholds = range(0, 1, 0.01); | ||
|
||
| let truePositives = 0; | ||
| let falsePositives = 0; | ||
| let trueNegatives = 0; | ||
| let falseNegatives = 0; | ||
|
|
||
| for (const threshold of thresholds) { | ||
hawestra marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| for (const [index, yProba] of probabilityY.entries()) { | ||
| // if the probability of predicting the positive label is greater than the | ||
| // threshold then it's a true positive. | ||
| // otherwise, it's a false positive | ||
| if (yProba < threshold) { | ||
| if (binY[index]) { | ||
| falseNegatives++; | ||
| } else { | ||
| trueNegatives++; | ||
| } | ||
| } else if (binY[index]) { | ||
| truePositives++; | ||
| } else { | ||
| falsePositives++; | ||
| } | ||
| } | ||
| addROCPoint( | ||
| truePositives, | ||
| falsePositives, | ||
| trueNegatives, | ||
| falseNegatives, | ||
| rocData | ||
| ); | ||
| truePositives = falsePositives = trueNegatives = falseNegatives = 0; | ||
hawestra marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| return rocData; | ||
| } | ||
|
|
||
| function addROCPoint( | ||
| truePositives: number, | ||
| falsePositives: number, | ||
| trueNegatives: number, | ||
| falseNegatives: number, | ||
| rocData: IROCData | ||
| ): void { | ||
| // prevent division by 0 | ||
| const totalNegatives = trueNegatives + falsePositives; | ||
| const totalPositives = truePositives + falseNegatives; | ||
| const tpr = totalPositives === 0 ? 1 : truePositives / totalPositives; | ||
| const fpr = totalNegatives === 0 ? 1 : falsePositives / totalNegatives; | ||
| rocData.points.push({ x: fpr, y: tpr }); | ||
| } | ||
|
|
||
| export function binarizeData( | ||
| yData: string[] | number[] | number[][], | ||
| classes: string[] | number[] | ||
| ): number[][] { | ||
| // binarize labels in a one-vs-all fashion according to | ||
| const yBinData: number[][] = []; | ||
| for (const yDatum of yData) { | ||
| const binaryData = classes.map((c) => { | ||
| return c === yDatum ? 1 : 0; | ||
| }); | ||
| yBinData.push(binaryData); | ||
| } | ||
| return yBinData; | ||
| } | ||
|
|
||
| // based on https://msdata.visualstudio.com/Vienna/_git/AzureMlCli?path=/src/azureml-metrics/azureml/metrics/_classification.py&version=GBmaster | ||
| export function calculateAUCData(dataset: IDataset): SeriesOptionsType[] { | ||
| if (!dataset.probability_y || !dataset.class_names) { | ||
| // TODO: show warning message | ||
| return [...getStaticROCData()]; | ||
| } | ||
|
|
||
| // temporary, replace with dataset.classnames | ||
| const cNames = [0, 1]; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should only run for binary classification, right? So we probably need a check somewhere to disable the component otherwise. For multiclass one could do one vs all but I don't know if anyone wants that.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in the studio, there's an auc chart for multiclass so i assumed we'd want that (binarizeData is supposed to handle this case), but i'll discuss with Minsoo tomorrow
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's probably 1 vs all, so that makes sense. |
||
| const binTrueY = binarizeData(dataset.true_y, cNames); | ||
| console.log(binTrueY); | ||
| // transpose in order to group class data together | ||
| const perClassBinY = unzip(binTrueY); | ||
| const perClassProba = unzip(dataset.probability_y); | ||
| const data = []; | ||
| // loop through each class to calculate roc data per class | ||
| for (const [i, element] of perClassBinY.entries()) { | ||
| const classROCData = calculatePerClassROCData(perClassProba[i], element); | ||
| const classData = { | ||
| data: classROCData.points, | ||
| // TODO: check class_names length earlier ? | ||
| name: cNames ? cNames[i] : "", | ||
| type: "line" | ||
| }; | ||
| data.push(classData); | ||
| } | ||
|
|
||
| const allData = [...data, ...getStaticROCData()]; | ||
| return allData as SeriesOptionsType[]; | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -838,7 +838,9 @@ | |
| "rowIndex": "Row index", | ||
| "absoluteIndex": "Absolute index", | ||
| "xValue": "X-value", | ||
| "yValue": "Y-value" | ||
| "yValue": "Y-value", | ||
| "Ideal": "Ideal", | ||
| "Random": "Random" | ||
| }, | ||
| "Cohort": { | ||
| "_cohort.comment": "a subset of the data is called a cohort", | ||
|
|
@@ -1818,6 +1820,7 @@ | |
| "regressionDistributionPivotItem": "Target distribution", | ||
| "metricsVisualizationsPivotItem": "Metrics visualizations", | ||
| "confusionMatrixPivotItem": "Confusion matrix", | ||
| "AUCPivotItem": "AUC Chart", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is fine, but we'll probably need a little explainer somewhere. Especially because it's an acronym
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will discuss with Minsoo! |
||
| "disaggregatedAnalysisFeatureSelectionPlaceholder": "Select features to generate the feature-based analysis.", | ||
| "tableCountTooltip": "Cohort {0} contains {1} instances.", | ||
| "tableMetricTooltip": "The model's {0} on cohort {1} is {2}", | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.