Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -462,22 +462,32 @@ public static class Prediction implements Serializable, Comparable<Prediction>,

// The predicted value
@XmlElement
private double prediction = 0;
private double prediction = 0D;

// The confidence of the prediction.
@XmlElement
private double confidence = 0D;

public Prediction() {
super();
}

public Prediction(String name, double value) {
public Prediction(String name, double prediction) {
this(name, prediction, 0D);
}

public Prediction(String name, double prediction, double confidence) {
super();
this.name = name;
this.prediction = value;
this.prediction = prediction;
this.confidence = confidence;
}

public Prediction(Prediction o) {
super();
this.name = o.name;
this.prediction = o.prediction;
this.confidence = o.confidence;
}

public Prediction duplicate() {
Expand All @@ -500,9 +510,17 @@ public void setPrediction(double prediction) {
this.prediction = prediction;
}

public double getConfidence() {
return confidence;
}

public void setConfidence(double confidence) {
this.confidence = confidence;
}

@Override
public int hashCode() {
return new HashCodeBuilder(17, 37).append(name).append(prediction).toHashCode();
return new HashCodeBuilder(17, 37).append(name).append(prediction).append(confidence).toHashCode();
}

@Override
Expand All @@ -515,20 +533,22 @@ public boolean equals(Object o) {
}
if (o instanceof Prediction) {
Prediction other = (Prediction) o;
return new EqualsBuilder().append(this.name, other.name).append(this.prediction, other.prediction).isEquals();
return new EqualsBuilder().append(this.name, other.name).append(this.prediction, other.prediction).append(this.confidence, other.confidence)
.isEquals();
} else {
return false;
}
}

@Override
public int compareTo(Prediction o) {
return new CompareToBuilder().append(name, o.name).append(prediction, o.prediction).toComparison();
return new CompareToBuilder().append(name, o.name).append(prediction, o.prediction).append(confidence, o.confidence).toComparison();
}

@Override
public String toString() {
return new StringBuilder().append("Name: ").append(this.name).append(" Prediction: ").append(this.prediction).toString();
return new StringBuilder().append("Name: ").append(this.name).append(" Prediction: ").append(this.prediction).append(" Confidence: ")
.append(this.confidence).toString();
}

public static Schema<Prediction> getSchema() {
Expand All @@ -540,7 +560,7 @@ public Schema<Prediction> cachedSchema() {
return SCHEMA;
}

private static final Schema<Prediction> SCHEMA = new Schema<Prediction>() {
private static final Schema<Prediction> SCHEMA = new Schema<>() {
public Prediction newMessage() {
return new Prediction();
}
Expand All @@ -564,6 +584,7 @@ public boolean isInitialized(Prediction message) {
public void writeTo(Output output, Prediction message) throws IOException {
output.writeString(1, message.name, false);
output.writeDouble(2, message.prediction, false);
output.writeDouble(3, message.confidence, false);
}

public void mergeFrom(Input input, Prediction message) throws IOException {
Expand All @@ -576,6 +597,9 @@ public void mergeFrom(Input input, Prediction message) throws IOException {
case 2:
message.prediction = input.readDouble();
break;
case 3:
message.confidence = input.readDouble();
break;
default:
input.handleUnknownField(number, this);
break;
Expand All @@ -589,21 +613,24 @@ public String getFieldName(int number) {
return "name";
case 2:
return "prediction";
case 3:
return "confidence";
default:
return null;
}
}

public int getFieldNumber(String name) {
final Integer number = fieldMap.get(name);
return number == null ? 0 : number.intValue();
return number == null ? 0 : number;
}

final HashMap<String,Integer> fieldMap = new LinkedHashMap<>();

{
fieldMap.put("name", 1);
fieldMap.put("prediction", 2);
fieldMap.put("confidence", 3);
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ default String getPredictionsStr(Set<BaseQueryMetric.Prediction> predictions) {
List<BaseQueryMetric.Prediction> predictionsList = new ArrayList<>(predictions);
Collections.sort(predictionsList);
for (BaseQueryMetric.Prediction p : predictionsList) {
builder.append(delimiter).append(p.getName()).append(" = ").append(p.getPrediction());
builder.append(delimiter).append(p.getName()).append(" = ").append(p.getPrediction()).append(":").append(p.getConfidence());
delimiter = "<br>";
}
} else {
Expand Down