1use crate::{algorithms, Uint};
2
3impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
11 #[inline]
18 #[must_use]
19 pub fn reduce_mod(mut self, modulus: Self) -> Self {
20 if modulus == Self::ZERO {
21 return Self::ZERO;
22 }
23 if self >= modulus {
24 self %= modulus;
25 }
26 self
27 }
28
29 #[inline]
33 #[must_use]
34 pub fn add_mod(self, rhs: Self, modulus: Self) -> Self {
35 let lhs = self.reduce_mod(modulus);
37 let rhs = rhs.reduce_mod(modulus);
38
39 let (mut result, overflow) = lhs.overflowing_add(rhs);
41 if overflow || result >= modulus {
42 result -= modulus;
43 }
44 result
45 }
46
47 #[inline]
54 #[must_use]
55 pub fn mul_mod(self, rhs: Self, mut modulus: Self) -> Self {
56 if modulus == Self::ZERO {
57 return Self::ZERO;
58 }
59
60 let mut product = [[0u64; 2]; LIMBS];
63 let product_len = crate::nlimbs(2 * BITS);
64 debug_assert!(2 * LIMBS >= product_len);
65 let product = unsafe {
67 core::slice::from_raw_parts_mut(product.as_mut_ptr().cast::<u64>(), product_len)
68 };
69
70 let overflow = algorithms::addmul(product, self.as_limbs(), rhs.as_limbs());
72 debug_assert!(!overflow);
73
74 algorithms::div(product, &mut modulus.limbs);
77
78 modulus
79 }
80
81 #[inline]
85 #[must_use]
86 pub fn pow_mod(mut self, mut exp: Self, modulus: Self) -> Self {
87 if modulus == Self::ZERO || modulus <= Self::from(1) {
88 return Self::ZERO;
90 }
91
92 let mut result = Self::from(1);
94 while exp > Self::ZERO {
95 if exp.limbs[0] & 1 == 1 {
97 result = result.mul_mod(self, modulus);
98 }
99
100 self = self.mul_mod(self, modulus);
102 exp >>= 1;
103 }
104 result
105 }
106
107 #[inline]
111 #[must_use]
112 pub fn inv_mod(self, modulus: Self) -> Option<Self> {
113 algorithms::inv_mod(self, modulus)
114 }
115
116 #[inline]
155 #[must_use]
156 #[cfg(feature = "alloc")] pub fn mul_redc(self, other: Self, modulus: Self, inv: u64) -> Self {
158 if BITS == 0 {
159 return Self::ZERO;
160 }
161 assert_eq!(inv.wrapping_mul(modulus.limbs[0]), u64::MAX);
162 let mut result = Self::ZERO;
163 algorithms::mul_redc(
164 self.as_limbs(),
165 other.as_limbs(),
166 &mut result.limbs,
167 modulus.as_limbs(),
168 inv,
169 );
170 debug_assert!(result < modulus);
171 result
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{const_for, nlimbs};
    use core::cmp::min;
    use proptest::{proptest, test_runner::Config};

    #[test]
    fn test_commutative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, m: U)| {
                assert_eq!(a.mul_mod(b, m), b.mul_mod(a, m));
            });
        });
    }

    #[test]
    fn test_associative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U, m: U)| {
                assert_eq!(a.mul_mod(b.mul_mod(c, m), m), a.mul_mod(b, m).mul_mod(c, m));
            });
        });
    }

    #[test]
    fn test_distributive() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U, m: U)| {
                assert_eq!(a.mul_mod(b.add_mod(c, m), m), a.mul_mod(b, m).add_mod(a.mul_mod(c, m), m));
            });
        });
    }

    #[test]
    fn test_add_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U, m: U)| {
                assert_eq!(value.add_mod(U::from(0), m), value.reduce_mod(m));
            });
        });
    }

    #[test]
    fn test_mul_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U, m: U)| {
                assert_eq!(value.mul_mod(U::from(0), m), U::ZERO);
                assert_eq!(value.mul_mod(U::from(1), m), value.reduce_mod(m));
            });
        });
    }

    #[test]
    fn test_pow_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, m: U)| {
                assert_eq!(a.pow_mod(U::from(0), m), U::from(1).reduce_mod(m));
                assert_eq!(a.pow_mod(U::from(1), m), a.reduce_mod(m));
            });
        });
    }

    #[test]
    fn test_pow_rules() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            let mut config = Config::default();
            // Modular exponentiation with random full-width exponents is slow, so
            // keep the number of cases small (even smaller for wide integers).
            config.cases = min(config.cases, if BITS > 500 { 1 } else { 3 });
            // (a * b)^c ≡ a^c * b^c (mod m)
            proptest!(config, |(a: U, b: U, c: U, m: U)| {
                assert_eq!(a.mul_mod(b, m).pow_mod(c, m), a.pow_mod(c, m).mul_mod(b.pow_mod(c, m), m));
            });
        });
    }

    #[test]
    fn test_inv() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            let mut config = Config::default();
            // Inversion is comparatively expensive; limit cases for wide integers.
            config.cases = min(config.cases, if BITS > 500 { 6 } else { 20 });
            proptest!(config, |(a: U, m: U)| {
                // Whenever an inverse is reported, it must actually invert `a`.
                if let Some(inv) = a.inv_mod(m) {
                    assert_eq!(a.mul_mod(inv, m), U::from(1));
                }
            });
        });
    }

    #[test]
    // `Uint::mul_redc` is only compiled with the `alloc` feature, so this test must
    // be gated the same way or the test module fails to build without it.
    #[cfg(feature = "alloc")]
    fn test_mul_redc() {
        // Imported here so they are not flagged as unused when `alloc` is disabled.
        use crate::aliases::U64;
        use proptest::prop_assume;

        const_for!(BITS in NON_ZERO if (BITS >= 16) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, m: U)| {
                prop_assume!(m >= U::from(2));
                // Montgomery reduction needs the inverse of the low limb mod 2^64;
                // skip moduli for which `inv_ring` reports none.
                if let Some(inv) = U64::from(m.as_limbs()[0]).inv_ring() {
                    // `mul_redc` expects the *negated* inverse.
                    let inv = (-inv).as_limbs()[0];

                    // Convert to Montgomery form with R = 2^(64 * LIMBS) and check
                    // that REDC(aR * bR) == abR (mod m).
                    let r = U::from(2).pow_mod(U::from(64 * LIMBS), m);
                    let ar = a.mul_mod(r, m);
                    let br = b.mul_mod(r, m);
                    let expected = a.mul_mod(b, m).mul_mod(r, m);

                    assert_eq!(ar.mul_redc(br, m, inv), expected);
                }
            });
        });
    }
}