Skip to content

Commit c087315

Browse files
committed
slightly improve flexibility and perf of modular multiplicative inverse
1 parent 2661739 commit c087315

File tree

3 files changed

+130
-40
lines changed

3 files changed

+130
-40
lines changed

modular_arithmetic/include/hurchalla/modular_arithmetic/detail/impl_modular_multiplicative_inverse.h

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,18 @@ namespace hurchalla { namespace detail {
2323
// note: uses a static member function to disallow ADL.
2424
struct impl_modular_multiplicative_inverse {
2525
template <typename T>
26-
HURCHALLA_FORCE_INLINE static T call(T val, T modulus)
26+
HURCHALLA_FORCE_INLINE static T call(T val, T modulus, T& gcd)
2727
{
2828
static_assert(ut_numeric_limits<T>::is_integer, "");
2929
static_assert(!(ut_numeric_limits<T>::is_signed), "");
3030
// I decided not to support modulus<=1, since it's not likely to be used and
3131
// it complicates the return type and adds conditional branches.
3232
HPBC_PRECONDITION2(modulus > 1);
3333

34-
// POSTCONDITION: Returns 0 if the inverse doesn't exist. Otherwise returns
34+
// POSTCONDITION1: Returns 0 if the inverse doesn't exist. Otherwise returns
3535
// the inverse (which is never 0, given that modulus>1).
36+
// POSTCONDITION2: Sets gcd to the greatest common divisor of val and
37+
// modulus. Note that if the inverse exists, we will get gcd == 1.
3638

3739
using U = typename safely_promote_unsigned<T>::type;
3840
using S = typename extensible_make_signed<U>::type;
@@ -42,32 +44,36 @@ struct impl_modular_multiplicative_inverse {
4244
// calculating only what is needed for the modular multiplicative inverse.
4345
S y1=0;
4446
U a1=modulus;
45-
{
46-
S y0=1;
47-
U a2=val;
48-
U q=0;
49-
while (a2 != 0) {
50-
S y2 = static_cast<S>(y0 - static_cast<S>(q)*y1);
51-
y0=y1;
52-
y1=y2;
53-
U a0=a1;
54-
a1=a2;
47+
S y0=1;
48+
U a2=val;
49+
U q=0;
50+
while (a2 > 1) {
51+
S y2 = static_cast<S>(y0 - static_cast<S>(q)*y1);
52+
y0=y1;
53+
y1=y2;
54+
U a0=a1;
55+
a1=a2;
5556

56-
q = static_cast<U>(a0/a1);
57-
a2 = static_cast<U>(a0 - q*a1);
58-
}
57+
q = static_cast<U>(a0/a1);
58+
a2 = static_cast<U>(a0 - q*a1);
5959
}
60-
S y = y1;
61-
U gcd = a1;
62-
if (gcd == 1) {
60+
HPBC_ASSERT2(a1 > 1);
61+
62+
if (a2 == 1) {
63+
gcd = 1;
64+
S y = static_cast<S>(y0 - static_cast<S>(q)*y1);
6365
// inv = (y<0) ? y+modulus : y
6466
U inv = ::hurchalla::conditional_select(y<0,
65-
static_cast<U>(static_cast<U>(y)+modulus),
66-
static_cast<U>(y));
67+
static_cast<U>(static_cast<U>(y)+modulus),
68+
static_cast<U>(y));
6769
HPBC_POSTCONDITION2(inv < modulus);
6870
return static_cast<T>(inv);
69-
} else
71+
}
72+
else {
73+
gcd = static_cast<T>(a1);
74+
HPBC_ASSERT2(gcd > 1);
7075
return 0;
76+
}
7177
}
7278
};
7379

modular_arithmetic/include/hurchalla/modular_arithmetic/modular_multiplicative_inverse.h

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,43 @@
1212
#include "hurchalla/modular_arithmetic/detail/impl_modular_multiplicative_inverse.h"
1313
#include "hurchalla/modular_arithmetic/modular_multiplication.h"
1414
#include "hurchalla/util/traits/ut_numeric_limits.h"
15+
#include "hurchalla/util/compiler_macros.h"
1516
#include "hurchalla/util/programming_by_contract.h"
1617

1718
namespace hurchalla {
1819

1920

21+
// Returns the modular multiplicative inverse of 'a', mod the modulus.
22+
// Also assigns the gcd of 'a' and modulus to the reference parameter gcd.
23+
//
2024
// Note: Calling with a < modulus slightly improves performance.
2125
// [The multiplicative inverse is an integer > 0 and < modulus, such that
2226
// a * multiplicative_inverse == 1 (mod modulus). It is a unique number,
2327
// but it exists if and only if 'a' and 'modulus' are coprime.]
2428
template <typename T>
25-
T modular_multiplicative_inverse(T a, T modulus)
29+
T modular_multiplicative_inverse(T a, T modulus, T& gcd)
2630
{
2731
static_assert(ut_numeric_limits<T>::is_integer, "");
2832
static_assert(!(ut_numeric_limits<T>::is_signed), "");
2933
HPBC_PRECONDITION(modulus > 1);
3034

31-
T inverse = detail::impl_modular_multiplicative_inverse::call(a, modulus);
35+
T inv = detail::impl_modular_multiplicative_inverse::call(a, modulus, gcd);
3236

33-
HPBC_POSTCONDITION(inverse < modulus);
37+
HPBC_POSTCONDITION(inv < modulus);
3438
//POSTCONDITION: Returns 0 if the inverse does not exist. Otherwise returns
3539
// the value of the inverse (which is never 0, given that modulus>1).
36-
HPBC_POSTCONDITION(inverse == 0 ||
40+
HPBC_POSTCONDITION(inv == 0 ||
3741
::hurchalla::modular_multiplication_prereduced_inputs(
38-
static_cast<T>(a % modulus), inverse, modulus) == 1);
39-
return inverse;
42+
static_cast<T>(a % modulus), inv, modulus) == 1);
43+
return inv;
44+
}
45+
46+
// Same as the above function, except that it omits the gcd reference parameter.
47+
template <typename T>
48+
HURCHALLA_FORCE_INLINE T modular_multiplicative_inverse(T a, T modulus)
49+
{
50+
T gcd; // ignored
51+
return modular_multiplicative_inverse(a, modulus, gcd);
4052
}
4153

4254

test/modular_arithmetic/test_modular_multiplicative_inverse.cpp

Lines changed: 85 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,21 @@ void exhaustive_test_uint8_t()
5959
{
6060
for (std::uint8_t modulus=255; modulus>1; --modulus) {
6161
for (std::uint8_t a=0; a<modulus; ++a) {
62-
std::uint8_t inv = hc::modular_multiplicative_inverse(a, modulus);
62+
std::uint8_t g;
63+
std::uint8_t inv = hc::modular_multiplicative_inverse(a, modulus,g);
64+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
6365
if (inv == 0)
64-
EXPECT_TRUE(1 < testmmi::gcd(a, modulus));
65-
else
66+
EXPECT_TRUE(g > 1);
67+
else {
68+
EXPECT_TRUE(g == 1);
6669
EXPECT_TRUE(static_cast<std::uint8_t>(1) ==
6770
hc::modular_multiplication_prereduced_inputs(a, inv,
6871
modulus));
72+
}
73+
if (g > 1)
74+
EXPECT_TRUE(inv == 0);
75+
else
76+
EXPECT_TRUE(inv > 0);
6977
}
7078
}
7179
}
@@ -74,25 +82,34 @@ void exhaustive_test_uint8_t()
7482
template <typename T>
7583
void test_modulus(T modulus)
7684
{
85+
T g;
86+
7787
T a = 0;
7888
EXPECT_TRUE(static_cast<T>(0) ==
79-
hc::modular_multiplicative_inverse(a, modulus));
89+
hc::modular_multiplicative_inverse(a, modulus, g));
90+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
91+
8092
a = 1;
8193
EXPECT_TRUE(static_cast<T>(1) ==
82-
hc::modular_multiplicative_inverse(a, modulus));
94+
hc::modular_multiplicative_inverse(a, modulus, g));
95+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
96+
8397
a = modulus;
8498
EXPECT_TRUE(static_cast<T>(0) ==
85-
hc::modular_multiplicative_inverse(a, modulus));
99+
hc::modular_multiplicative_inverse(a, modulus, g));
100+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
86101

87102
T tmax = hc::ut_numeric_limits<T>::max();
88103
if (modulus < tmax) {
89104
a = static_cast<T>(modulus + 1);
90105
EXPECT_TRUE(static_cast<T>(1) ==
91-
hc::modular_multiplicative_inverse(a, modulus));
106+
hc::modular_multiplicative_inverse(a, modulus, g));
107+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
92108
}
93109

94110
a = 2;
95-
T inverse = hc::modular_multiplicative_inverse(a, modulus);
111+
T inverse = hc::modular_multiplicative_inverse(a, modulus, g);
112+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
96113
if (inverse == 0)
97114
EXPECT_TRUE(1 < testmmi::gcd(a, modulus));
98115
else
@@ -101,7 +118,8 @@ void test_modulus(T modulus)
101118
static_cast<T>(a % modulus), inverse, modulus));
102119

103120
a = 3;
104-
inverse = hc::modular_multiplicative_inverse(a, modulus);
121+
inverse = hc::modular_multiplicative_inverse(a, modulus, g);
122+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
105123
if (inverse == 0)
106124
EXPECT_TRUE(1 < testmmi::gcd(a, modulus));
107125
else
@@ -110,28 +128,32 @@ void test_modulus(T modulus)
110128
static_cast<T>(a % modulus), inverse, modulus));
111129

112130
a = static_cast<T>(modulus - 1);
113-
inverse = hc::modular_multiplicative_inverse(a, modulus);
131+
inverse = hc::modular_multiplicative_inverse(a, modulus, g);
132+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
114133
EXPECT_TRUE(static_cast<T>(1) ==
115134
hc::modular_multiplication_prereduced_inputs(a, inverse, modulus));
116135

117136
a = static_cast<T>(modulus - 2);
118-
inverse = hc::modular_multiplicative_inverse(a, modulus);
137+
inverse = hc::modular_multiplicative_inverse(a, modulus, g);
138+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
119139
if (inverse == 0)
120140
EXPECT_TRUE(1 < testmmi::gcd(a, modulus));
121141
else
122142
EXPECT_TRUE(static_cast<T>(1) ==
123143
hc::modular_multiplication_prereduced_inputs(a, inverse, modulus));
124144

125145
a = static_cast<T>(modulus/2);
126-
inverse = hc::modular_multiplicative_inverse(a, modulus);
146+
inverse = hc::modular_multiplicative_inverse(a, modulus, g);
147+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
127148
if (inverse == 0)
128149
EXPECT_TRUE(1 < testmmi::gcd(a, modulus));
129150
else
130151
EXPECT_TRUE(static_cast<T>(1) ==
131152
hc::modular_multiplication_prereduced_inputs(a, inverse, modulus));
132153

133154
a++;
134-
inverse = hc::modular_multiplicative_inverse(a, modulus);
155+
inverse = hc::modular_multiplicative_inverse(a, modulus, g);
156+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
135157
if (inverse == 0)
136158
EXPECT_TRUE(1 < testmmi::gcd(a, modulus));
137159
else
@@ -144,26 +166,52 @@ void test_modulus(T modulus)
144166
template <typename T>
145167
void test_modular_multiplicative_inverse()
146168
{
169+
T g;
170+
147171
// test with a few basic examples first
148172
T modulus = 13;
149173
T a = 5;
150174
EXPECT_TRUE(static_cast<T>(8) ==
151175
hc::modular_multiplicative_inverse(a, modulus));
176+
EXPECT_TRUE(static_cast<T>(8) ==
177+
hc::modular_multiplicative_inverse(a, modulus, g));
178+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
179+
152180
a = 7;
153181
EXPECT_TRUE(static_cast<T>(2) ==
154182
hc::modular_multiplicative_inverse(a, modulus));
183+
EXPECT_TRUE(static_cast<T>(2) ==
184+
hc::modular_multiplicative_inverse(a, modulus, g));
185+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
186+
155187
a = 4;
156188
EXPECT_TRUE(static_cast<T>(10) ==
157189
hc::modular_multiplicative_inverse(a, modulus));
190+
EXPECT_TRUE(static_cast<T>(10) ==
191+
hc::modular_multiplicative_inverse(a, modulus, g));
192+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
193+
158194
a = 17;
159195
EXPECT_TRUE(static_cast<T>(10) ==
160196
hc::modular_multiplicative_inverse(a, modulus));
197+
EXPECT_TRUE(static_cast<T>(10) ==
198+
hc::modular_multiplicative_inverse(a, modulus, g));
199+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
200+
161201
a = 1;
162202
EXPECT_TRUE(static_cast<T>(1) ==
163203
hc::modular_multiplicative_inverse(a, modulus));
204+
EXPECT_TRUE(static_cast<T>(1) ==
205+
hc::modular_multiplicative_inverse(a, modulus, g));
206+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
207+
164208
a = 14;
165209
EXPECT_TRUE(static_cast<T>(1) ==
166210
hc::modular_multiplicative_inverse(a, modulus));
211+
EXPECT_TRUE(static_cast<T>(1) ==
212+
hc::modular_multiplicative_inverse(a, modulus, g));
213+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
214+
167215

168216

169217
// modular_multiplicative_inverse() indicates the inverse doesn't exist by
@@ -172,12 +220,24 @@ void test_modular_multiplicative_inverse()
172220
modulus = 21; // a modulus of 21 shares the factor 3 with a.
173221
EXPECT_TRUE(static_cast<T>(0) ==
174222
hc::modular_multiplicative_inverse(a, modulus));
223+
EXPECT_TRUE(static_cast<T>(0) ==
224+
hc::modular_multiplicative_inverse(a, modulus, g));
225+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
226+
175227
a = 0;
176228
EXPECT_TRUE(static_cast<T>(0) ==
177229
hc::modular_multiplicative_inverse(a, modulus));
230+
EXPECT_TRUE(static_cast<T>(0) ==
231+
hc::modular_multiplicative_inverse(a, modulus, g));
232+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
233+
178234
a = 1;
179235
EXPECT_TRUE(static_cast<T>(1) ==
180236
hc::modular_multiplicative_inverse(a, modulus));
237+
EXPECT_TRUE(static_cast<T>(1) ==
238+
hc::modular_multiplicative_inverse(a, modulus, g));
239+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
240+
181241

182242
a = 7;
183243
modulus = 16;
@@ -204,12 +264,24 @@ void test_modular_multiplicative_inverse()
204264
a = 0;
205265
EXPECT_TRUE(static_cast<T>(0) ==
206266
hc::modular_multiplicative_inverse(a, modulus));
267+
EXPECT_TRUE(static_cast<T>(0) ==
268+
hc::modular_multiplicative_inverse(a, modulus, g));
269+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
270+
207271
a = 1;
208272
EXPECT_TRUE(static_cast<T>(1) ==
209273
hc::modular_multiplicative_inverse(a, modulus));
274+
EXPECT_TRUE(static_cast<T>(1) ==
275+
hc::modular_multiplicative_inverse(a, modulus, g));
276+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
277+
210278
a = 5;
211279
EXPECT_TRUE(static_cast<T>(1) ==
212280
hc::modular_multiplicative_inverse(a, modulus));
281+
EXPECT_TRUE(static_cast<T>(1) ==
282+
hc::modular_multiplicative_inverse(a, modulus, g));
283+
EXPECT_TRUE(g == testmmi::gcd(a, modulus));
284+
213285

214286
modulus = hc::ut_numeric_limits<T>::max();
215287
test_modulus(modulus);

0 commit comments

Comments
 (0)