17
17
#include < ginkgo/core/distributed/matrix.hpp>
18
18
#include < ginkgo/core/matrix/csr.hpp>
19
19
#include < ginkgo/core/matrix/dense.hpp>
20
+ #include < ginkgo/core/matrix/identity.hpp>
20
21
21
22
#include " core/base/utils.hpp"
22
23
#include " core/config/config_helper.hpp"
23
24
#include " core/config/dispatch.hpp"
24
25
#include " core/distributed/helpers.hpp"
26
+ #include " core/matrix/csr_kernels.hpp"
25
27
26
28
27
29
namespace gko {
28
30
namespace experimental {
29
31
namespace distributed {
30
32
namespace preconditioner {
33
// File-local (anonymous namespace) registration of the kernel operation that
// the l1_smoother branch further down in this file runs on the executor.
+ namespace {
34
+
35
+
36
// Registers an operation named `row_wise_sum` bound to the
// `csr::row_wise_sum` kernel (presumably declared in
// core/matrix/csr_kernels.hpp, which is included above -- confirm).
// The code below dispatches it via `exec->run(make_row_wise_sum(...))`.
+ GKO_REGISTER_OPERATION (row_wise_sum, csr::row_wise_sum);
37
+
38
+
39
+ }  // namespace
31
40
32
41
33
42
template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
@@ -36,7 +45,7 @@ Schwarz<ValueType, LocalIndexType, GlobalIndexType>::parse(
36
45
const config::pnode& config, const config::registry& context,
37
46
const config::type_descriptor& td_for_child)
38
47
{
39
- auto params = Schwarz<ValueType, LocalIndexType, GlobalIndexType> ::build ();
48
+ auto params = Schwarz::build ();
40
49
41
50
if (auto & obj = config.get (" generated_local_solver" )) {
42
51
params.with_generated_local_solver (
@@ -47,7 +56,9 @@ Schwarz<ValueType, LocalIndexType, GlobalIndexType>::parse(
47
56
gko::config::parse_or_get_factory<const LinOpFactory>(
48
57
obj, context, td_for_child));
49
58
}
50
-
59
+ if (auto & obj = config.get (" l1_smoother" )) {
60
+ params.with_l1_smoother (obj.get_boolean ());
61
+ }
51
62
return params;
52
63
}
53
64
@@ -76,7 +87,6 @@ template <typename VectorType>
76
87
void Schwarz<ValueType, LocalIndexType, GlobalIndexType>::apply_dense_impl(
77
88
const VectorType* dense_b, VectorType* dense_x) const
78
89
{
79
- using Vector = matrix::Dense<ValueType>;
80
90
auto exec = this ->get_executor ();
81
91
if (this ->local_solver_ != nullptr ) {
82
92
this ->local_solver_ ->apply (gko::detail::get_local (dense_b),
@@ -130,14 +140,47 @@ void Schwarz<ValueType, LocalIndexType, GlobalIndexType>::generate(
130
140
" Requires either a generated solver or an solver factory" );
131
141
}
132
142
133
- if (parameters_.local_solver ) {
134
- this ->set_solver (gko::share (parameters_.local_solver ->generate (
135
- as<experimental::distributed::Matrix<
136
- ValueType, LocalIndexType, GlobalIndexType>>(system_matrix)
137
- ->get_local_matrix ())));
143
+ if (parameters_.generated_local_solver ) {
144
+ this ->set_solver (parameters_.generated_local_solver );
145
+ return ;
146
+ }
138
147
148
+ auto local_matrix =
149
+ as<Matrix<ValueType, LocalIndexType, GlobalIndexType>>(system_matrix)
150
+ ->get_local_matrix ();
151
+
152
+ if (parameters_.l1_smoother ) {
153
+ auto exec = this ->get_executor ();
154
+
155
+ using Csr = matrix::Csr<ValueType, LocalIndexType>;
156
+ auto local_matrix_copy = share (Csr::create (exec));
157
+ as<ConvertibleTo<Csr>>(local_matrix)->convert_to (local_matrix_copy);
158
+
159
+ auto non_local_matrix = copy_and_convert_to<Csr>(
160
+ exec, as<Matrix<ValueType, LocalIndexType, GlobalIndexType>>(
161
+ system_matrix)
162
+ ->get_non_local_matrix ());
163
+
164
+ array<ValueType> l1_diag_arr{exec, local_matrix->get_size ()[0 ]};
165
+
166
+ exec->run (make_row_wise_sum (non_local_matrix.get (), l1_diag_arr, true ));
167
+
168
+ // compute local_matrix_copy <- diag(l1) + local_matrix_copy
169
+ auto l1_diag = matrix::Diagonal<ValueType>::create (
170
+ exec, local_matrix->get_size ()[0 ], std::move (l1_diag_arr));
171
+ auto l1_diag_csr = Csr::create (exec);
172
+ l1_diag->move_to (l1_diag_csr);
173
+ auto id = matrix::Identity<ValueType>::create (
174
+ exec, local_matrix->get_size ()[0 ]);
175
+ auto one = initialize<matrix::Dense<ValueType>>(
176
+ {::gko::one<ValueType>()}, exec);
177
+ l1_diag_csr->apply (one, id, one, local_matrix_copy);
178
+
179
+ this ->set_solver (
180
+ gko::share (parameters_.local_solver ->generate (local_matrix_copy)));
139
181
} else {
140
- this ->set_solver (parameters_.generated_local_solver );
182
+ this ->set_solver (
183
+ gko::share (parameters_.local_solver ->generate (local_matrix)));
141
184
}
142
185
}
143
186
0 commit comments