Skip to content

Use math solver dispatch to separate math solver from main model #916

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
f53e281
remove instantiation
TonyXiang8787 Mar 13, 2025
1fea6c5
start dispatcher
TonyXiang8787 Mar 13, 2025
4bae0c7
create
TonyXiang8787 Mar 13, 2025
892965f
proxy object
TonyXiang8787 Mar 13, 2025
3f8c6c9
destroy
TonyXiang8787 Mar 13, 2025
7353255
solver outside
TonyXiang8787 Mar 13, 2025
429aaee
try to arrange main model
TonyXiang8787 Mar 13, 2025
5507528
arrange main model wrapper
TonyXiang8787 Mar 13, 2025
513f2e6
allocate c api
TonyXiang8787 Mar 13, 2025
e2d05b3
fix some errors'
TonyXiang8787 Mar 13, 2025
cfcc5ac
test passes
TonyXiang8787 Mar 13, 2025
f07f84b
also remove ybus
TonyXiang8787 Mar 14, 2025
cd134de
forward declare YBus
TonyXiang8787 Mar 14, 2025
ef71596
separate math solver
TonyXiang8787 Mar 14, 2025
b12a1cb
add math solver
TonyXiang8787 Mar 14, 2025
de5f174
fix clang format
TonyXiang8787 Mar 14, 2025
aa363c1
Merge branch 'main' into math-solver-dispatch
TonyXiang8787 Mar 14, 2025
19061b3
fix clang-tidy
TonyXiang8787 Mar 14, 2025
0dbd300
resolve comments
TonyXiang8787 Mar 16, 2025
b3b057b
use const ref
TonyXiang8787 Mar 16, 2025
9aa1478
Merge branch 'main' into math-solver-dispatch
TonyXiang8787 Mar 17, 2025
a3bb539
[skip ci] format
TonyXiang8787 Mar 17, 2025
a4d1ee5
Merge branch 'main' into math-solver-dispatch
TonyXiang8787 Mar 17, 2025
b293f93
[skip ci] try to use abstract base class
TonyXiang8787 Mar 17, 2025
2dae0da
[skip ci] inheritance
TonyXiang8787 Mar 17, 2025
dbd5aa2
finish using abstract base class
TonyXiang8787 Mar 17, 2025
1d70d09
try to fix concept
TonyXiang8787 Mar 17, 2025
7cce637
Update power_grid_model_c/power_grid_model/include/power_grid_model/m…
TonyXiang8787 Mar 17, 2025
da95e63
format
TonyXiang8787 Mar 17, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
#pragma once

#include "../math_solver/y_bus.hpp"
#include "../math_solver/math_solver_dispatch.hpp"

namespace power_grid_model::main_core {

struct MathState {
std::vector<YBus<symmetric_t>> y_bus_vec_sym;
std::vector<YBus<asymmetric_t>> y_bus_vec_asym;
std::vector<MathSolver<symmetric_t>> math_solvers_sym;
std::vector<MathSolver<asymmetric_t>> math_solvers_asym;
std::vector<MathSolverProxy<symmetric_t>> math_solvers_sym;
std::vector<MathSolverProxy<asymmetric_t>> math_solvers_asym;
};

inline void clear(MathState& math_state) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ class MainModel {
public:
using Options = MainModelOptions;

explicit MainModel(double system_frequency, ConstDataset const& input_data, Idx pos = 0)
: impl_{std::make_unique<Impl>(system_frequency, input_data, pos)} {}
explicit MainModel(double system_frequency, meta_data::MetaData const& meta_data)
: impl_{std::make_unique<Impl>(system_frequency, meta_data)} {};
explicit MainModel(double system_frequency, ConstDataset const& input_data,
MathSolverDispatcher const* math_solver_dispatcher, Idx pos = 0)
: impl_{std::make_unique<Impl>(system_frequency, input_data, math_solver_dispatcher, pos)} {}
explicit MainModel(double system_frequency, meta_data::MetaData const& meta_data,
MathSolverDispatcher const* math_solver_dispatcher)
: impl_{std::make_unique<Impl>(system_frequency, meta_data, math_solver_dispatcher)} {};

// deep copy
MainModel(MainModel const& other) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#include "auxiliary/output.hpp"

// math model include
#include "math_solver/math_solver.hpp"
#include "math_solver/math_solver_dispatch.hpp"

#include "optimizer/optimizer.hpp"

Expand Down Expand Up @@ -161,16 +161,22 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
using Options = MainModelOptions;

// constructor with data
explicit MainModelImpl(double system_frequency, ConstDataset const& input_data, Idx pos = 0)
: system_frequency_{system_frequency}, meta_data_{&input_data.meta_data()} {
explicit MainModelImpl(double system_frequency, ConstDataset const& input_data,
MathSolverDispatcher const* math_solver_dispatcher, Idx pos = 0)
: system_frequency_{system_frequency},
meta_data_{&input_data.meta_data()},
math_solver_dispatcher_{math_solver_dispatcher} {
assert(input_data.get_description().dataset->name == std::string_view("input"));
add_components(input_data, pos);
set_construction_complete();
}

// constructor with only frequency
explicit MainModelImpl(double system_frequency, meta_data::MetaData const& meta_data)
: system_frequency_{system_frequency}, meta_data_{&meta_data} {}
explicit MainModelImpl(double system_frequency, meta_data::MetaData const& meta_data,
MathSolverDispatcher const* math_solver_dispatcher)
: system_frequency_{system_frequency},
meta_data_{&meta_data},
math_solver_dispatcher_{math_solver_dispatcher} {}

private:
// helper function to get what components are present in the update data
Expand Down Expand Up @@ -441,9 +447,9 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
template <symmetry_tag sym> auto calculate_power_flow_(double err_tol, Idx max_iter) {
return [this, err_tol, max_iter](MainModelState const& state,
CalculationMethod calculation_method) -> std::vector<SolverOutput<sym>> {
return calculate_<SolverOutput<sym>, MathSolver<sym>, YBus<sym>, PowerFlowInput<sym>>(
return calculate_<SolverOutput<sym>, MathSolverProxy<sym>, YBus<sym>, PowerFlowInput<sym>>(
[&state](Idx n_math_solvers) { return prepare_power_flow_input<sym>(state, n_math_solvers); },
[this, err_tol, max_iter, calculation_method](MathSolver<sym>& solver, YBus<sym> const& y_bus,
[this, err_tol, max_iter, calculation_method](MathSolverProxy<sym>& solver, YBus<sym> const& y_bus,
PowerFlowInput<sym> const& input) {
return solver.run_power_flow(input, err_tol, max_iter, calculation_info_, calculation_method,
y_bus);
Expand All @@ -454,9 +460,9 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
template <symmetry_tag sym> auto calculate_state_estimation_(double err_tol, Idx max_iter) {
return [this, err_tol, max_iter](MainModelState const& state,
CalculationMethod calculation_method) -> std::vector<SolverOutput<sym>> {
return calculate_<SolverOutput<sym>, MathSolver<sym>, YBus<sym>, StateEstimationInput<sym>>(
return calculate_<SolverOutput<sym>, MathSolverProxy<sym>, YBus<sym>, StateEstimationInput<sym>>(
[&state](Idx n_math_solvers) { return prepare_state_estimation_input<sym>(state, n_math_solvers); },
[this, err_tol, max_iter, calculation_method](MathSolver<sym>& solver, YBus<sym> const& y_bus,
[this, err_tol, max_iter, calculation_method](MathSolverProxy<sym>& solver, YBus<sym> const& y_bus,
StateEstimationInput<sym> const& input) {
return solver.run_state_estimation(input, err_tol, max_iter, calculation_info_, calculation_method,
y_bus);
Expand All @@ -468,12 +474,12 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
return [this,
voltage_scaling](MainModelState const& /*state*/,
CalculationMethod calculation_method) -> std::vector<ShortCircuitSolverOutput<sym>> {
return calculate_<ShortCircuitSolverOutput<sym>, MathSolver<sym>, YBus<sym>, ShortCircuitInput>(
return calculate_<ShortCircuitSolverOutput<sym>, MathSolverProxy<sym>, YBus<sym>, ShortCircuitInput>(
[this, voltage_scaling](Idx /* n_math_solvers */) {
assert(is_topology_up_to_date_ && is_parameter_up_to_date<sym>());
return prepare_short_circuit_input<sym>(voltage_scaling);
},
[this, calculation_method](MathSolver<sym>& solver, YBus<sym> const& y_bus,
[this, calculation_method](MathSolverProxy<sym>& solver, YBus<sym> const& y_bus,
ShortCircuitInput const& input) {
return solver.run_short_circuit(input, calculation_info_, calculation_method, y_bus);
});
Expand Down Expand Up @@ -832,6 +838,7 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis

double system_frequency_;
meta_data::MetaData const* meta_data_;
MathSolverDispatcher const* math_solver_dispatcher_;

MainModelState state_;
// math model
Expand Down Expand Up @@ -859,7 +866,7 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
}
}

template <symmetry_tag sym> std::vector<MathSolver<sym>>& get_solvers() {
template <symmetry_tag sym> std::vector<MathSolverProxy<sym>>& get_solvers() {
if constexpr (is_symmetric_v<sym>) {
return math_state_.math_solvers_sym;
} else {
Expand Down Expand Up @@ -1071,7 +1078,7 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
* The default lambda `include_all` always returns `true`.
*/
template <calculation_input_type CalcStructOut, typename CalcParamOut,
std::vector<CalcParamOut>(CalcStructOut::*comp_vect), class ComponentIn,
std::vector<CalcParamOut>(CalcStructOut::* comp_vect), class ComponentIn,
std::invocable<Idx> PredicateIn = IncludeAll>
requires std::convertible_to<std::invoke_result_t<PredicateIn, Idx>, bool>
static void prepare_input(MainModelState const& state, std::vector<Idx2D> const& components,
Expand All @@ -1090,7 +1097,7 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
}

template <calculation_input_type CalcStructOut, typename CalcParamOut,
std::vector<CalcParamOut>(CalcStructOut::*comp_vect), class ComponentIn,
std::vector<CalcParamOut>(CalcStructOut::* comp_vect), class ComponentIn,
std::invocable<Idx> PredicateIn = IncludeAll>
requires std::convertible_to<std::invoke_result_t<PredicateIn, Idx>, bool>
static void prepare_input(MainModelState const& state, std::vector<Idx2D> const& components,
Expand All @@ -1110,7 +1117,7 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
}
}

template <symmetry_tag sym, IntSVector(StateEstimationInput<sym>::*component), class Component>
template <symmetry_tag sym, IntSVector(StateEstimationInput<sym>::* component), class Component>
static void prepare_input_status(MainModelState const& state, std::vector<Idx2D> const& objects,
std::vector<StateEstimationInput<sym>>& input) {
for (Idx i = 0, n = narrow_cast<Idx>(objects.size()); i != n; ++i) {
Expand Down Expand Up @@ -1291,7 +1298,7 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis
}

template <symmetry_tag sym> void prepare_solvers() {
std::vector<MathSolver<sym>>& solvers = get_solvers<sym>();
std::vector<MathSolverProxy<sym>>& solvers = get_solvers<sym>();
// rebuild topology if needed
if (!is_topology_up_to_date_) {
rebuild_topology();
Expand All @@ -1305,8 +1312,9 @@ class MainModelImpl<ExtraRetrievableTypes<ExtraRetrievableType...>, ComponentLis

solvers.clear();
solvers.reserve(n_math_solvers_);
std::ranges::transform(state_.math_topology, std::back_inserter(solvers),
[](auto math_topo) { return MathSolver<sym>{std::move(math_topo)}; });
std::ranges::transform(state_.math_topology, std::back_inserter(solvers), [this](auto math_topo) {
return MathSolverProxy<sym>{math_solver_dispatcher_, std::move(math_topo)};
});
for (Idx idx = 0; idx < n_math_solvers_; ++idx) {
get_y_bus<sym>()[idx].register_parameters_changed_callback(
[solver = std::ref(solvers[idx])](bool changed) { solver.get().parameters_changed(changed); });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,6 @@ class IterativeCurrentPFSolver : public IterativePFSolver<sym_type, IterativeCur
}
};

template class IterativeCurrentPFSolver<symmetric_t>;
template class IterativeCurrentPFSolver<asymmetric_t>;

} // namespace iterative_current_pf

using iterative_current_pf::IterativeCurrentPFSolver;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,6 @@ template <symmetry_tag sym_type> class IterativeLinearSESolver {
}
};

template class IterativeLinearSESolver<symmetric_t>;
template class IterativeLinearSESolver<asymmetric_t>;

} // namespace iterative_linear_se

using iterative_linear_se::IterativeLinearSESolver;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ template <symmetry_tag sym_type> class LinearPFSolver {
}
};

template class LinearPFSolver<symmetric_t>;
template class LinearPFSolver<asymmetric_t>;
} // namespace linear_pf

using linear_pf::LinearPFSolver;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,6 @@ template <symmetry_tag sym> class MathSolver {
}
};

template class MathSolver<symmetric_t>;
template class MathSolver<asymmetric_t>;
} // namespace math_solver

using math_solver::MathSolver;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// SPDX-FileCopyrightText: Contributors to the Power Grid Model project <powergridmodel@lfenergy.org>
//
// SPDX-License-Identifier: MPL-2.0

// runtime dispatch for math solver
// this can separate math solver into a different translation unit

#pragma once

#include "../calculation_parameters.hpp"
#include "../common/common.hpp"
#include "../common/exception.hpp"
#include "../common/three_phase_tensor.hpp"
#include "../common/timer.hpp"

#include <memory>

namespace power_grid_model {

namespace math_solver {

// forward declare YBUs
template <symmetry_tag sym> class YBus;

template <template <symmetry_tag> class MathSolverType> struct math_solver_tag {};

struct MathSolverDispatcher {
template <symmetry_tag sym> struct MathSolverDispatcherConfig {
template <template <symmetry_tag> class MathSolverType>
constexpr MathSolverDispatcherConfig(math_solver_tag<MathSolverType>)
: create{[](std::shared_ptr<MathModelTopology const> const& topo_ptr) -> void* {
return new MathSolverType<sym>{topo_ptr};
}},
destroy{[](void const* solver) { delete reinterpret_cast<MathSolverType<sym> const*>(solver); }},
copy{[](void const* solver) -> void* {
return new MathSolverType<sym>{*reinterpret_cast<MathSolverType<sym> const*>(solver)};
}},
run_power_flow{[](void* solver, PowerFlowInput<sym> const& input, double err_tol, Idx max_iter,
CalculationInfo& calculation_info, CalculationMethod calculation_method,
YBus<sym> const& y_bus) {
return reinterpret_cast<MathSolverType<sym>*>(solver)->run_power_flow(
input, err_tol, max_iter, calculation_info, calculation_method, y_bus);
}},
run_state_estimation{[](void* solver, StateEstimationInput<sym> const& input, double err_tol,
Idx max_iter, CalculationInfo& calculation_info,
CalculationMethod calculation_method, YBus<sym> const& y_bus) {
return reinterpret_cast<MathSolverType<sym>*>(solver)->run_state_estimation(
input, err_tol, max_iter, calculation_info, calculation_method, y_bus);
}},
run_short_circuit{[](void* solver, ShortCircuitInput const& input, CalculationInfo& calculation_info,
CalculationMethod calculation_method, YBus<sym> const& y_bus) {
return reinterpret_cast<MathSolverType<sym>*>(solver)->run_short_circuit(input, calculation_info,
calculation_method, y_bus);
}},
clear_solver{[](void* solver) { reinterpret_cast<MathSolverType<sym>*>(solver)->clear_solver(); }},
parameters_changed{[](void* solver, bool changed) {
reinterpret_cast<MathSolverType<sym>*>(solver)->parameters_changed(changed);
}} {}

std::add_pointer_t<void*(std::shared_ptr<MathModelTopology const> const&)> create;
std::add_pointer_t<void(void const*)> destroy;
std::add_pointer_t<void*(void const*)> copy;
std::add_pointer_t<SolverOutput<sym>(void*, PowerFlowInput<sym> const&, double, Idx, CalculationInfo&,
CalculationMethod, YBus<sym> const&)>
run_power_flow;
std::add_pointer_t<SolverOutput<sym>(void*, StateEstimationInput<sym> const&, double, Idx, CalculationInfo&,
CalculationMethod, YBus<sym> const&)>
run_state_estimation;
std::add_pointer_t<ShortCircuitSolverOutput<sym>(void*, ShortCircuitInput const&, CalculationInfo&,
CalculationMethod, YBus<sym> const&)>
run_short_circuit;
std::add_pointer_t<void(void*)> clear_solver;
std::add_pointer_t<void(void*, bool)> parameters_changed;
};

template <template <symmetry_tag> class MathSolverType>
constexpr MathSolverDispatcher(math_solver_tag<MathSolverType>)
: sym_config{math_solver_tag<MathSolverType>{}}, asym_config{math_solver_tag<MathSolverType>{}} {}

template <symmetry_tag sym> MathSolverDispatcherConfig<sym> const& get_dispather_config() const {
if constexpr (is_symmetric_v<sym>) {
return sym_config;
} else {
return asym_config;
}
}
MathSolverDispatcherConfig<symmetric_t> sym_config;
MathSolverDispatcherConfig<asymmetric_t> asym_config;
};

template <symmetry_tag sym> class MathSolverProxy {
public:
explicit MathSolverProxy(MathSolverDispatcher const* dispatcher,
std::shared_ptr<MathModelTopology const> const& topo_ptr)
: dispatcher_{dispatcher},
solver_{dispatcher_->get_dispather_config<sym>().create(topo_ptr),
dispatcher_->get_dispather_config<sym>().destroy} {}
MathSolverProxy(MathSolverProxy const& other)
: dispatcher_{other.dispatcher_},
solver_{dispatcher_->get_dispather_config<sym>().copy(other.get_ptr()),
dispatcher_->get_dispather_config<sym>().destroy} {}
MathSolverProxy& operator=(MathSolverProxy const& other) {
if (this != &other) {
solver_.reset();
dispatcher_ = other.dispatcher_;
solver_ = std::unique_ptr<void, std::add_pointer_t<void(void const*)>>{
dispatcher_->get_dispather_config<sym>().copy(other.get_ptr()),
dispatcher_->get_dispather_config<sym>().destroy};
}
return *this;
}
MathSolverProxy(MathSolverProxy&& other) noexcept = default;
MathSolverProxy& operator=(MathSolverProxy&& other) noexcept = default;
~MathSolverProxy() = default;

SolverOutput<sym> run_power_flow(PowerFlowInput<sym> const& input, double err_tol, Idx max_iter,
CalculationInfo& calculation_info, CalculationMethod calculation_method,
YBus<sym> const& y_bus) {
return dispatcher_->get_dispather_config<sym>().run_power_flow(get_ptr(), input, err_tol, max_iter,
calculation_info, calculation_method, y_bus);
}

SolverOutput<sym> run_state_estimation(StateEstimationInput<sym> const& input, double err_tol, Idx max_iter,
CalculationInfo& calculation_info, CalculationMethod calculation_method,
YBus<sym> const& y_bus) {
return dispatcher_->get_dispather_config<sym>().run_state_estimation(
get_ptr(), input, err_tol, max_iter, calculation_info, calculation_method, y_bus);
}

ShortCircuitSolverOutput<sym> run_short_circuit(ShortCircuitInput const& input, CalculationInfo& calculation_info,
CalculationMethod calculation_method, YBus<sym> const& y_bus) {
return dispatcher_->get_dispather_config<sym>().run_short_circuit(get_ptr(), input, calculation_info,
calculation_method, y_bus);
}

void clear_solver() { dispatcher_->get_dispather_config<sym>().clear_solver(get_ptr()); }

void parameters_changed(bool changed) {
dispatcher_->get_dispather_config<sym>().parameters_changed(get_ptr(), changed);
}

private:
MathSolverDispatcher const* dispatcher_{};
std::unique_ptr<void, std::add_pointer_t<void(void const*)>> solver_;

void* get_ptr() { return solver_.get(); }
void const* get_ptr() const { return solver_.get(); }
};

} // namespace math_solver

template <symmetry_tag sym> using MathSolverProxy = math_solver::MathSolverProxy<sym>;

using MathSolverDispatcher = math_solver::MathSolverDispatcher;

} // namespace power_grid_model
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,6 @@ class NewtonRaphsonPFSolver : public IterativePFSolver<sym_type, NewtonRaphsonPF
}
};

template class NewtonRaphsonPFSolver<symmetric_t>;
template class NewtonRaphsonPFSolver<asymmetric_t>;

} // namespace newton_raphson_pf

using newton_raphson_pf::NewtonRaphsonPFSolver;
Expand Down
Loading
Loading