|
7 | 7 |
|
8 | 8 | #pragma once
|
9 | 9 |
|
| 10 | +#include "MantidAPI/AlgorithmManager.h" |
10 | 11 | #include "MantidAPI/MatrixWorkspace.h"
|
11 | 12 | #include "MantidAPI/WorkspaceGroup.h"
|
12 | 13 | #include "MantidAlgorithms/DllConfig.h"
|
| 14 | +#include "MantidKernel/MultiThreaded.h" |
| 15 | +#include <Eigen/Dense> |
13 | 16 | #include <optional>
|
| 17 | +#include <unsupported/Eigen/AutoDiff> |
14 | 18 |
|
15 | 19 | namespace Mantid::Algorithms {
|
16 | 20 | namespace PolarizationCorrectionsHelpers {
|
@@ -66,4 +70,109 @@ MANTID_ALGORITHMS_DLL const std::string &getORSONotationForSpinState(const std::
|
66 | 70 | MANTID_ALGORITHMS_DLL void addORSOLogForSpinState(const Mantid::API::MatrixWorkspace_sptr &ws,
|
67 | 71 | const std::string &spinState);
|
68 | 72 | } // namespace SpinStatesORSO
|
| 73 | + |
| 74 | +namespace Arithmetic { |
| 75 | + |
| 76 | +template <size_t N> class ErrorTypeHelper { |
| 77 | +public: |
| 78 | + using DerType = Eigen::Matrix<double, N, 1>; |
| 79 | + using InputArray = DerType; |
| 80 | + using ADScalar = Eigen::AutoDiffScalar<DerType>; |
| 81 | +}; |
| 82 | + |
| 83 | +template <size_t N, typename Func> class ErrorPropagation { |
| 84 | +public: |
| 85 | + using Types = ErrorTypeHelper<N>; |
| 86 | + using DerType = Types::DerType; |
| 87 | + using ADScalar = Types::ADScalar; |
| 88 | + using InputArray = Types::InputArray; |
| 89 | + ErrorPropagation(Func func) : computeFunc(std::move(func)) {} |
| 90 | + |
| 91 | + struct AutoDevResult { |
| 92 | + double value; |
| 93 | + double error; |
| 94 | + Eigen::Array<double, N, 1> derivatives; |
| 95 | + }; |
| 96 | + |
| 97 | + AutoDevResult evaluate(const InputArray &values, const InputArray &errors) const { |
| 98 | + std::array<ADScalar, N> x; |
| 99 | + for (size_t i = 0; i < N; ++i) { |
| 100 | + x[i] = ADScalar(values[i], DerType::Unit(N, i)); |
| 101 | + } |
| 102 | + const ADScalar y = computeFunc(x); |
| 103 | + const auto &derivatives = y.derivatives(); |
| 104 | + return {y.value(), std::sqrt((derivatives.array().square() * errors.array().square()).sum()), derivatives}; |
| 105 | + } |
| 106 | + |
| 107 | + template <std::same_as<API::MatrixWorkspace_sptr>... Ts> |
| 108 | + API::MatrixWorkspace_sptr evaluateWorkspaces(const bool outputWorkspaceDistribution, Ts... args) const { |
| 109 | + return evaluateWorkspacesImpl(outputWorkspaceDistribution, std::forward<Ts>(args)...); |
| 110 | + } |
| 111 | + |
| 112 | + template <std::same_as<API::MatrixWorkspace_sptr>... Ts> |
| 113 | + API::MatrixWorkspace_sptr evaluateWorkspaces(Ts... args) const { |
| 114 | + return evaluateWorkspacesImpl(std::nullopt, std::forward<Ts>(args)...); |
| 115 | + } |
| 116 | + |
| 117 | +private: |
| 118 | + Func computeFunc; |
| 119 | + |
| 120 | + template <std::same_as<API::MatrixWorkspace_sptr>... Ts> |
| 121 | + API::MatrixWorkspace_sptr evaluateWorkspacesImpl(std::optional<bool> outputWorkspaceDistribution, Ts... args) const { |
| 122 | + const auto firstWs = std::get<0>(std::forward_as_tuple(args...)); |
| 123 | + API::MatrixWorkspace_sptr outWs = firstWs->clone(); |
| 124 | + |
| 125 | + if (outWs->id() == "EventWorkspace") { |
| 126 | + outWs = convertToWorkspace2D(outWs); |
| 127 | + } |
| 128 | + |
| 129 | + const size_t numSpec = outWs->getNumberHistograms(); |
| 130 | + const size_t specSize = outWs->blocksize(); |
| 131 | + |
| 132 | + // cppcheck-suppress unreadVariable |
| 133 | + const bool isThreadSafe = Kernel::threadSafe((*args)..., *outWs); |
| 134 | + // cppcheck-suppress unreadVariable |
| 135 | + const bool specOverBins = numSpec > specSize; |
| 136 | + |
| 137 | + PARALLEL_FOR_IF(isThreadSafe && specOverBins) |
| 138 | + for (int64_t i = 0; i < static_cast<int64_t>(numSpec); i++) { |
| 139 | + auto &yOut = outWs->mutableY(i); |
| 140 | + auto &eOut = outWs->mutableE(i); |
| 141 | + |
| 142 | + PARALLEL_FOR_IF(isThreadSafe && !specOverBins) |
| 143 | + for (int64_t j = 0; j < static_cast<int64_t>(specSize); ++j) { |
| 144 | + const auto result = evaluate(InputArray{args->y(i)[j]...}, InputArray(args->e(i)[j]...)); |
| 145 | + yOut[j] = result.value; |
| 146 | + eOut[j] = result.error; |
| 147 | + } |
| 148 | + } |
| 149 | + |
| 150 | + if (outputWorkspaceDistribution.has_value()) { |
| 151 | + outWs->setDistribution(outputWorkspaceDistribution.value()); |
| 152 | + } |
| 153 | + return outWs; |
| 154 | + } |
| 155 | + |
| 156 | + API::MatrixWorkspace_sptr runWorkspaceConversionAlg(const API::MatrixWorkspace_sptr &workspace, |
| 157 | + const std::string &algName) const { |
| 158 | + auto conversionAlg = API::AlgorithmManager::Instance().create(algName); |
| 159 | + conversionAlg->initialize(); |
| 160 | + conversionAlg->setChild(true); |
| 161 | + conversionAlg->setProperty("InputWorkspace", workspace); |
| 162 | + conversionAlg->setProperty("OutputWorkspace", workspace->getName()); |
| 163 | + conversionAlg->execute(); |
| 164 | + return conversionAlg->getProperty("OutputWorkspace"); |
| 165 | + } |
| 166 | + |
| 167 | + API::MatrixWorkspace_sptr convertToWorkspace2D(const API::MatrixWorkspace_sptr &workspace) const { |
| 168 | + runWorkspaceConversionAlg(workspace, "ConvertToHistogram"); |
| 169 | + return runWorkspaceConversionAlg(workspace, "ConvertToMatrixWorkspace"); |
| 170 | + } |
| 171 | +}; |
| 172 | + |
| 173 | +template <size_t N, typename Func> auto makeErrorPropagation(Func &&func) { |
| 174 | + return ErrorPropagation<N, std::decay_t<Func>>(std::forward<Func>(func)); |
| 175 | +} |
| 176 | + |
| 177 | +} // namespace Arithmetic |
69 | 178 | } // namespace Mantid::Algorithms
|
0 commit comments