Skip to content

Commit db4f29d

Browse files
authored
Merge pull request #38075 from mantidproject/ewm7196-nans-flag
Add flag to `CompareWorkspaces` so users can specify `NaN == NaN` behavior
2 parents f94d911 + b2acf19 commit db4f29d

38 files changed

+598
-194
lines changed

Framework/API/inc/MantidAPI/Column.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,13 @@ class MANTID_API_DLL Column {
157157
return vec;
158158
}
159159

160-
virtual bool equals(const Column &, double) const { throw std::runtime_error("equals not implemented"); };
160+
virtual bool equals(const Column &, double, bool const = false) const {
161+
throw std::runtime_error("equals not implemented");
162+
};
161163

162-
virtual bool equalsRelErr(const Column &, double) const { throw std::runtime_error("equals not implemented"); };
164+
virtual bool equalsRelErr(const Column &, double, bool const = false) const {
165+
throw std::runtime_error("equals not implemented");
166+
};
163167

164168
protected:
165169
/// Sets the new column size.

Framework/Algorithms/inc/MantidAlgorithms/CompareWorkspaces.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ class MANTID_ALGORITHMS_DLL CompareWorkspaces final : public API::Algorithm {
7676
"testing process.";
7777
}
7878

79-
static bool withinAbsoluteTolerance(double x1, double x2, double atol);
80-
static bool withinRelativeTolerance(double x1, double x2, double rtol);
79+
static bool withinAbsoluteTolerance(double x1, double x2, double atol, bool const nanEqual = false);
80+
static bool withinRelativeTolerance(double x1, double x2, double rtol, bool const nanEqual = false);
8181

8282
private:
8383
/// Initialise algorithm

Framework/Algorithms/src/CompareWorkspaces.cpp

Lines changed: 48 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "MantidGeometry/Crystal/IPeak.h"
2323
#include "MantidGeometry/Instrument/ComponentInfo.h"
2424
#include "MantidGeometry/Instrument/DetectorInfo.h"
25+
#include "MantidKernel/FloatingPointComparison.h"
2526
#include "MantidKernel/Unit.h"
2627

2728
namespace Mantid::Algorithms {
@@ -71,23 +72,18 @@ int compareEventLists(Kernel::Logger &logger, const EventList &el1, const EventL
7172
const auto &e1 = events1[i];
7273
const auto &e2 = events2[i];
7374

74-
bool diffpulse = false;
75-
bool difftof = false;
76-
bool diffweight = false;
77-
if (std::abs(e1.pulseTime().totalNanoseconds() - e2.pulseTime().totalNanoseconds()) > tolPulse) {
78-
diffpulse = true;
79-
++numdiffpulse;
80-
}
81-
if (std::abs(e1.tof() - e2.tof()) > tolTof) {
82-
difftof = true;
83-
++numdifftof;
84-
}
75+
bool diffpulse =
76+
!withinAbsoluteDifference(e1.pulseTime().totalNanoseconds(), e2.pulseTime().totalNanoseconds(), tolPulse);
77+
bool difftof = !withinAbsoluteDifference(e1.tof(), e2.tof(), tolTof);
78+
bool diffweight = !withinAbsoluteDifference(e1.weight(), e2.weight(), tolWeight);
8579
if (diffpulse && difftof)
86-
++numdiffboth;
87-
if (std::abs(e1.weight() - e2.weight()) > tolWeight) {
88-
diffweight = true;
89-
++numdiffweight;
90-
}
80+
numdiffboth++;
81+
if (diffpulse)
82+
numdiffpulse++;
83+
if (difftof)
84+
numdifftof++;
85+
if (diffweight)
86+
numdiffweight++;
9187

9288
bool same = (!diffpulse) && (!difftof) && (!diffweight);
9389
if (!same) {
@@ -148,6 +144,8 @@ void CompareWorkspaces::init() {
148144
"Very often such logs are huge so making it true should be "
149145
"the last option.");
150146

147+
declareProperty("NaNsEqual", false, "Whether NaN values should compare as equal with other NaN values.");
148+
151149
declareProperty("NumberMismatchedSpectraToPrint", 1, "Number of mismatched spectra from lowest to be listed. ");
152150

153151
declareProperty("DetailedPrintIndex", EMPTY_INT(), "Mismatched spectra that will be printed out in details. ");
@@ -172,13 +170,14 @@ void CompareWorkspaces::exec() {
172170
m_parallelComparison = false;
173171

174172
double const tolerance = getProperty("Tolerance");
173+
bool const nanEqual = getProperty("NaNsEqual");
175174
if (getProperty("ToleranceRelErr")) {
176-
this->m_compare = [tolerance](double const x1, double const x2) -> bool {
177-
return CompareWorkspaces::withinRelativeTolerance(x1, x2, tolerance);
175+
this->m_compare = [tolerance, nanEqual](double const x1, double const x2) -> bool {
176+
return CompareWorkspaces::withinRelativeTolerance(x1, x2, tolerance, nanEqual);
178177
};
179178
} else {
180-
this->m_compare = [tolerance](double const x1, double const x2) -> bool {
181-
return CompareWorkspaces::withinAbsoluteTolerance(x1, x2, tolerance);
179+
this->m_compare = [tolerance, nanEqual](double const x1, double const x2) -> bool {
180+
return CompareWorkspaces::withinAbsoluteTolerance(x1, x2, tolerance, nanEqual);
182181
};
183182
}
184183

@@ -1049,10 +1048,11 @@ void CompareWorkspaces::doPeaksComparison(PeaksWorkspace_sptr tws1, PeaksWorkspa
10491048
}
10501049

10511050
const bool isRelErr = getProperty("ToleranceRelErr");
1051+
const bool checkAllData = getProperty("CheckAllData");
10521052
for (int i = 0; i < tws1->getNumberPeaks(); i++) {
10531053
const Peak &peak1 = tws1->getPeak(i);
10541054
const Peak &peak2 = tws2->getPeak(i);
1055-
for (size_t j = 0; j < tws1->columnCount(); j++) {
1055+
for (std::size_t j = 0; j < tws1->columnCount(); j++) {
10561056
std::shared_ptr<const API::Column> col = tws1->getColumn(j);
10571057
std::string name = col->name();
10581058
double s1 = 0.0;
@@ -1127,7 +1127,8 @@ void CompareWorkspaces::doPeaksComparison(PeaksWorkspace_sptr tws1, PeaksWorkspa
11271127
<< "value1 = " << s1 << "\n"
11281128
<< "value2 = " << s2 << "\n";
11291129
recordMismatch("Data mismatch");
1130-
return;
1130+
if (!checkAllData)
1131+
return;
11311132
}
11321133
}
11331134
}
@@ -1163,8 +1164,10 @@ void CompareWorkspaces::doLeanElasticPeaksComparison(const LeanElasticPeaksWorks
11631164

11641165
const double tolerance = getProperty("Tolerance");
11651166
const bool isRelErr = getProperty("ToleranceRelErr");
1167+
const bool checkAllData = getProperty("CheckAllData");
1168+
const bool nanEqual = getProperty("NaNsEqual");
11661169
for (int peakIndex = 0; peakIndex < ipws1->getNumberPeaks(); peakIndex++) {
1167-
for (size_t j = 0; j < ipws1->columnCount(); j++) {
1170+
for (std::size_t j = 0; j < ipws1->columnCount(); j++) {
11681171
std::shared_ptr<const API::Column> col = ipws1->getColumn(j);
11691172
const std::string name = col->name();
11701173
double s1 = 0.0;
@@ -1229,10 +1232,10 @@ void CompareWorkspaces::doLeanElasticPeaksComparison(const LeanElasticPeaksWorks
12291232
// bool mismatch = !m_compare(s1, s2)
12301233
// can replace this if/else, and isRelErr and tolerance can be deleted
12311234
if (isRelErr && name != "QLab" && name != "QSample") {
1232-
if (!withinRelativeTolerance(s1, s2, tolerance)) {
1235+
if (!withinRelativeTolerance(s1, s2, tolerance, nanEqual)) {
12331236
mismatch = true;
12341237
}
1235-
} else if (!withinAbsoluteTolerance(s1, s2, tolerance)) {
1238+
} else if (!withinAbsoluteTolerance(s1, s2, tolerance, nanEqual)) {
12361239
mismatch = true;
12371240
}
12381241
if (mismatch) {
@@ -1242,7 +1245,8 @@ void CompareWorkspaces::doLeanElasticPeaksComparison(const LeanElasticPeaksWorks
12421245
<< "value1 = " << s1 << "\n"
12431246
<< "value2 = " << s2 << "\n";
12441247
recordMismatch("Data mismatch");
1245-
return;
1248+
if (!checkAllData)
1249+
return;
12461250
}
12471251
}
12481252
}
@@ -1283,19 +1287,23 @@ void CompareWorkspaces::doTableComparison(const API::ITableWorkspace_const_sptr
12831287

12841288
const bool checkAllData = getProperty("CheckAllData");
12851289
const bool isRelErr = getProperty("ToleranceRelErr");
1290+
const bool nanEqual = getProperty("NaNsEqual");
12861291
const double tolerance = getProperty("Tolerance");
12871292
bool mismatch;
1288-
for (size_t i = 0; i < numCols; ++i) {
1293+
for (std::size_t i = 0; i < numCols; ++i) {
12891294
const auto c1 = tws1->getColumn(i);
12901295
const auto c2 = tws2->getColumn(i);
12911296

12921297
if (isRelErr) {
1293-
mismatch = !c1->equalsRelErr(*c2, tolerance);
1298+
mismatch = !c1->equalsRelErr(*c2, tolerance, nanEqual);
12941299
} else {
1295-
mismatch = !c1->equals(*c2, tolerance);
1300+
mismatch = !c1->equals(*c2, tolerance, nanEqual);
12961301
}
12971302
if (mismatch) {
12981303
g_log.debug() << "Table data mismatch at column " << i << "\n";
1304+
for (std::size_t j = 0; j < c1->size(); j++) {
1305+
g_log.debug() << "\t" << j << " | " << c1->cell<double>(j) << ", " << c2->cell<double>(j) << "\n";
1306+
}
12991307
recordMismatch("Table data mismatch");
13001308
if (!checkAllData) {
13011309
return;
@@ -1356,12 +1364,15 @@ this error is within the limits requested.
13561364
@param x1 -- first value to check difference
13571365
@param x2 -- second value to check difference
13581366
@param atol -- the tolerance of the comparison. Must be nonnegative
1367+
@param nanEqual -- whether two NaNs compare as equal
13591368
13601369
@returns true if absolute difference is within the tolerance; false otherwise
13611370
*/
1362-
bool CompareWorkspaces::withinAbsoluteTolerance(double const x1, double const x2, double const atol) {
1363-
// NOTE !(|x1-x2| > atol) is not the same as |x1-x2| <= atol
1364-
return !(std::abs(x1 - x2) > atol);
1371+
bool CompareWorkspaces::withinAbsoluteTolerance(double const x1, double const x2, double const atol,
1372+
bool const nanEqual) {
1373+
if (nanEqual && std::isnan(x1) && std::isnan(x2))
1374+
return true;
1375+
return Kernel::withinAbsoluteDifference(x1, x2, atol);
13651376
}
13661377

13671378
//------------------------------------------------------------------------------------------------
@@ -1371,24 +1382,15 @@ this error is within the limits requested.
13711382
@param x1 -- first value to check difference
13721383
@param x2 -- second value to check difference
13731384
@param rtol -- the tolerance of the comparison. Must be nonnegative
1385+
@param nanEqual -- whether two NaNs compare as equal
13741386
13751387
@returns true if relative difference is within the tolerance; false otherwise
13761388
@returns true if error or false if the relative value is within the limits requested
13771389
*/
1378-
bool CompareWorkspaces::withinRelativeTolerance(double const x1, double const x2, double const rtol) {
1379-
// calculate difference
1380-
double const num = std::abs(x1 - x2);
1381-
// return early if the values are equal
1382-
if (num == 0.0)
1390+
bool CompareWorkspaces::withinRelativeTolerance(double const x1, double const x2, double const rtol,
1391+
bool const nanEqual) {
1392+
if (nanEqual && std::isnan(x1) && std::isnan(x2))
13831393
return true;
1384-
// create the average magnitude for comparison
1385-
double const den = 0.5 * (std::abs(x1) + std::abs(x2));
1386-
// return early, possibly avoids a multiplication
1387-
// NOTE if den<1, then divsion will only make num larger
1388-
// NOTE if den<1 but num<=rtol, we cannot conclude anything
1389-
if (den <= 1.0 && num > rtol)
1390-
return false;
1391-
// NOTE !(num > rtol*den) is not the same as (num <= rtol*den)
1392-
return !(num > (rtol * den));
1394+
return Kernel::withinRelativeDifference(x1, x2, rtol);
13931395
}
13941396
} // namespace Mantid::Algorithms

Framework/Algorithms/src/DetectorEfficiencyCor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ void DetectorEfficiencyCor::exec() {
123123
int64_t numHists = m_inputWS->getNumberHistograms();
124124
auto numHists_d = static_cast<double>(numHists);
125125
const auto progStep = static_cast<int64_t>(ceil(numHists_d / 100.0));
126-
auto &spectrumInfo = m_inputWS->spectrumInfo();
126+
auto const &spectrumInfo = m_inputWS->spectrumInfo();
127127

128128
PARALLEL_FOR_IF(Kernel::threadSafe(*m_inputWS, *m_outputWS))
129129
for (int64_t i = 0; i < numHists; ++i) {

Framework/Algorithms/test/CompareWorkspacesTest.h

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,44 @@ class CompareWorkspacesTest : public CxxTest::TestSuite {
216216
checker.resetProperties();
217217
}
218218

219+
void test_NaNsEqual_true() {
220+
if (!checker.isInitialized())
221+
checker.initialize();
222+
223+
double const anan = std::numeric_limits<double>::quiet_NaN();
224+
225+
// a real and NaN are never equal
226+
WorkspaceSingleValue_sptr ws1 = WorkspaceCreationHelper::createWorkspaceSingleValue(1.1);
227+
WorkspaceSingleValue_sptr ws2 = WorkspaceCreationHelper::createWorkspaceSingleValue(anan);
228+
// is not equal if NaNsEqual set true
229+
TS_ASSERT_THROWS_NOTHING(checker.setProperty("NaNsEqual", true));
230+
TS_ASSERT_THROWS_NOTHING(checker.setProperty("Workspace1", ws1));
231+
TS_ASSERT_THROWS_NOTHING(checker.setProperty("Workspace2", ws2));
232+
TS_ASSERT(checker.execute());
233+
TS_ASSERT_EQUALS(checker.getPropertyValue("Result"), PROPERTY_VALUE_FALSE);
234+
// is not equal if NaNsEqual set false
235+
TS_ASSERT_THROWS_NOTHING(checker.setProperty("NaNsEqual", false));
236+
TS_ASSERT_THROWS_NOTHING(checker.setProperty("Workspace1", ws1));
237+
TS_ASSERT_THROWS_NOTHING(checker.setProperty("Workspace2", ws2));
238+
TS_ASSERT(checker.execute());
239+
TS_ASSERT_EQUALS(checker.getPropertyValue("Result"), PROPERTY_VALUE_FALSE);
240+
241+
// NaNs only compare equal if flag set
242+
WorkspaceSingleValue_sptr ws3 = WorkspaceCreationHelper::createWorkspaceSingleValue(anan);
243+
// is NOT equal if NaNsEqual set FALSE
244+
TS_ASSERT_THROWS_NOTHING(checker.setProperty("NaNsEqual", false));
245+
TS_ASSERT_THROWS_NOTHING(checker.setProperty("Workspace1", ws2));
246+
TS_ASSERT_THROWS_NOTHING(checker.setProperty("Workspace2", ws3));
247+
TS_ASSERT(checker.execute());
248+
TS_ASSERT_EQUALS(checker.getPropertyValue("Result"), PROPERTY_VALUE_FALSE);
249+
// ARE equal if NaNsEqual set TRUE
250+
TS_ASSERT_THROWS_NOTHING(checker.setProperty("NaNsEqual", true));
251+
TS_ASSERT_THROWS_NOTHING(checker.setProperty("Workspace1", ws2));
252+
TS_ASSERT_THROWS_NOTHING(checker.setProperty("Workspace2", ws3));
253+
TS_ASSERT(checker.execute());
254+
TS_ASSERT_EQUALS(checker.getPropertyValue("Result"), PROPERTY_VALUE_TRUE);
255+
}
256+
219257
void testPeaks_matches() {
220258
if (!checker.isInitialized())
221259
checker.initialize();
@@ -1193,6 +1231,83 @@ class CompareWorkspacesTest : public CxxTest::TestSuite {
11931231
TS_ASSERT_EQUALS(alg.getPropertyValue("Result"), PROPERTY_VALUE_TRUE);
11941232
}
11951233

1234+
void test_equal_tableworkspaces_match() {
1235+
std::string const col_type("double"), col_name("aColumn");
1236+
std::vector<double> col_values{1.0, 2.0, 3.0};
1237+
// create the table workspaces
1238+
Mantid::API::ITableWorkspace_sptr table1 = WorkspaceFactory::Instance().createTable();
1239+
table1->addColumn(col_type, col_name);
1240+
for (double val : col_values) {
1241+
TableRow newrow = table1->appendRow();
1242+
newrow << val;
1243+
}
1244+
Mantid::API::ITableWorkspace_sptr table2 = WorkspaceFactory::Instance().createTable();
1245+
table2->addColumn(col_type, col_name);
1246+
for (double val : col_values) {
1247+
TableRow newrow = table2->appendRow();
1248+
newrow << val;
1249+
}
1250+
1251+
Mantid::Algorithms::CompareWorkspaces alg;
1252+
alg.initialize();
1253+
TS_ASSERT_THROWS_NOTHING(alg.setProperty("Workspace1", table1));
1254+
TS_ASSERT_THROWS_NOTHING(alg.setProperty("Workspace2", table2));
1255+
TS_ASSERT(alg.execute());
1256+
TS_ASSERT_EQUALS(alg.getPropertyValue("Result"), PROPERTY_VALUE_TRUE);
1257+
}
1258+
1259+
void test_tableworkspace_NaNs_passes_with_flag() {
1260+
std::string const col_type("double"), col_name("aColumn");
1261+
std::vector<double> col_values{1.0, 2.0, std::numeric_limits<double>::quiet_NaN()};
1262+
// create the table workspaces
1263+
Mantid::API::ITableWorkspace_sptr table1 = WorkspaceFactory::Instance().createTable();
1264+
Mantid::API::ITableWorkspace_sptr table2 = WorkspaceFactory::Instance().createTable();
1265+
table1->addColumn(col_type, col_name);
1266+
table2->addColumn(col_type, col_name);
1267+
for (double val : col_values) {
1268+
TableRow newrow1 = table1->appendRow();
1269+
newrow1 << val;
1270+
TableRow newrow2 = table2->appendRow();
1271+
newrow2 << val;
1272+
}
1273+
Mantid::Algorithms::CompareWorkspaces alg;
1274+
alg.initialize();
1275+
TS_ASSERT_THROWS_NOTHING(alg.setProperty("Workspace1", table1));
1276+
TS_ASSERT_THROWS_NOTHING(alg.setProperty("Workspace2", table2));
1277+
TS_ASSERT_THROWS_NOTHING(alg.setProperty("NaNsEqual", true));
1278+
TS_ASSERT(alg.execute());
1279+
TS_ASSERT_EQUALS(alg.getPropertyValue("Result"), PROPERTY_VALUE_TRUE);
1280+
}
1281+
1282+
void test_tableworkspace_NaNs_fails() {
1283+
std::string const col_type("double"), col_name("aColumn");
1284+
std::vector<double> col_values1{1.0, 2.0, 3.0};
1285+
std::vector<double> col_values2{1.0, 2.0, std::numeric_limits<double>::quiet_NaN()};
1286+
// create the table workspaces
1287+
Mantid::API::ITableWorkspace_sptr table1 = WorkspaceFactory::Instance().createTable();
1288+
table1->addColumn(col_type, col_name);
1289+
for (double val : col_values1) {
1290+
TableRow newrow = table1->appendRow();
1291+
newrow << val;
1292+
}
1293+
Mantid::API::ITableWorkspace_sptr table2 = WorkspaceFactory::Instance().createTable();
1294+
table2->addColumn(col_type, col_name);
1295+
for (double val : col_values2) {
1296+
TableRow newrow = table2->appendRow();
1297+
newrow << val;
1298+
}
1299+
1300+
Mantid::Algorithms::CompareWorkspaces alg;
1301+
alg.initialize();
1302+
TS_ASSERT_THROWS_NOTHING(alg.setProperty("Workspace1", table1));
1303+
TS_ASSERT_THROWS_NOTHING(alg.setProperty("Workspace2", table2));
1304+
TS_ASSERT(alg.execute());
1305+
TS_ASSERT_EQUALS(alg.getPropertyValue("Result"), PROPERTY_VALUE_FALSE);
1306+
1307+
ITableWorkspace_sptr table = AnalysisDataService::Instance().retrieveWS<TableWorkspace>("compare_msgs");
1308+
TS_ASSERT_EQUALS(table->cell<std::string>(0, 0), "Table data mismatch");
1309+
}
1310+
11961311
void test_tableworkspace_different_column_names_fails() {
11971312
auto table1 = setupTableWorkspace();
11981313
table1->getColumn(5)->setName("SomethingElse");

0 commit comments

Comments
 (0)