ruint/
modular.rs

1use crate::{algorithms, Uint};
2
// FEATURE: sub_mod, neg_mod, div_mod, root_mod (`inv_mod` is already implemented below).
// For root_mod, see <https://en.wikipedia.org/wiki/Cipolla's_algorithm>
// FEATURE: mul_mod_redc, and maybe Barrett reduction.
// See also <https://static1.squarespace.com/static/61f7cacf2d7af938cad5b81c/t/62deb4e0c434f7134c2730ee/1658762465114/modular_multiplication.pdf>
// FEATURE: Modular wrapper class, like Wrapping.

10impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
11    /// ⚠️ Compute $\mod{\mathtt{self}}_{\mathtt{modulus}}$.
12    ///
13    /// **Warning.** This function is not part of the stable API.
14    ///
15    /// Returns zero if the modulus is zero.
16    // FEATURE: Reduce larger bit-sizes to smaller ones.
17    #[inline]
18    #[must_use]
19    pub fn reduce_mod(mut self, modulus: Self) -> Self {
20        if modulus == Self::ZERO {
21            return Self::ZERO;
22        }
23        if self >= modulus {
24            self %= modulus;
25        }
26        self
27    }
28
29    /// Compute $\mod{\mathtt{self} + \mathtt{rhs}}_{\mathtt{modulus}}$.
30    ///
31    /// Returns zero if the modulus is zero.
32    #[inline]
33    #[must_use]
34    pub fn add_mod(self, rhs: Self, modulus: Self) -> Self {
35        // Reduce inputs
36        let lhs = self.reduce_mod(modulus);
37        let rhs = rhs.reduce_mod(modulus);
38
39        // Compute the sum and conditionally subtract modulus once.
40        let (mut result, overflow) = lhs.overflowing_add(rhs);
41        if overflow || result >= modulus {
42            result -= modulus;
43        }
44        result
45    }
46
47    /// Compute $\mod{\mathtt{self} ⋅ \mathtt{rhs}}_{\mathtt{modulus}}$.
48    ///
49    /// Returns zero if the modulus is zero.
50    ///
51    /// See [`mul_redc`](Self::mul_redc) for a faster variant at the cost of
52    /// some pre-computation.
53    #[inline]
54    #[must_use]
55    pub fn mul_mod(self, rhs: Self, mut modulus: Self) -> Self {
56        if modulus == Self::ZERO {
57            return Self::ZERO;
58        }
59
60        // Allocate at least `nlimbs(2 * BITS)` limbs to store the product. This array
61        // casting is a workaround for `generic_const_exprs` not being stable.
62        let mut product = [[0u64; 2]; LIMBS];
63        let product_len = crate::nlimbs(2 * BITS);
64        debug_assert!(2 * LIMBS >= product_len);
65        // SAFETY: `[[u64; 2]; LIMBS] == [u64; 2 * LIMBS] >= [u64; nlimbs(2 * BITS)]`.
66        let product = unsafe {
67            core::slice::from_raw_parts_mut(product.as_mut_ptr().cast::<u64>(), product_len)
68        };
69
70        // Compute full product.
71        let overflow = algorithms::addmul(product, self.as_limbs(), rhs.as_limbs());
72        debug_assert!(!overflow);
73
74        // Compute modulus using `div_rem`.
75        // This stores the remainder in the divisor, `modulus`.
76        algorithms::div(product, &mut modulus.limbs);
77
78        modulus
79    }
80
81    /// Compute $\mod{\mathtt{self}^{\mathtt{rhs}}}_{\mathtt{modulus}}$.
82    ///
83    /// Returns zero if the modulus is zero.
84    #[inline]
85    #[must_use]
86    pub fn pow_mod(mut self, mut exp: Self, modulus: Self) -> Self {
87        if modulus == Self::ZERO || modulus <= Self::from(1) {
88            // Also covers Self::BITS == 0
89            return Self::ZERO;
90        }
91
92        // Exponentiation by squaring
93        let mut result = Self::from(1);
94        while exp > Self::ZERO {
95            // Multiply by base
96            if exp.limbs[0] & 1 == 1 {
97                result = result.mul_mod(self, modulus);
98            }
99
100            // Square base
101            self = self.mul_mod(self, modulus);
102            exp >>= 1;
103        }
104        result
105    }
106
107    /// Compute $\mod{\mathtt{self}^{-1}}_{\mathtt{modulus}}$.
108    ///
109    /// Returns `None` if the inverse does not exist.
110    #[inline]
111    #[must_use]
112    pub fn inv_mod(self, modulus: Self) -> Option<Self> {
113        algorithms::inv_mod(self, modulus)
114    }
115
116    /// Montgomery multiplication.
117    ///
118    /// Computes
119    ///
120    /// $$
121    /// \mod{\frac{\mathtt{self} ⋅ \mathtt{other}}{ 2^{64 ·
122    /// \mathtt{LIMBS}}}}_{\mathtt{modulus}} $$
123    ///
124    /// This is useful because it can be computed notably faster than
125    /// [`mul_mod`](Self::mul_mod). Many computations can be done by
126    /// pre-multiplying values with $R = 2^{64 · \mathtt{LIMBS}}$
127    /// and then using [`mul_redc`](Self::mul_redc) instead of
128    /// [`mul_mod`](Self::mul_mod).
129    ///
130    /// For this algorithm to work, it needs an extra parameter `inv` which must
131    /// be set to
132    ///
133    /// $$
134    /// \mathtt{inv} = \mod{\frac{-1}{\mathtt{modulus}} }_{2^{64}}
135    /// $$
136    ///
137    /// The `inv` value only exists for odd values of `modulus`. It can be
138    /// computed using [`inv_ring`](Self::inv_ring) from `U64`.
139    ///
140    /// ```
141    /// # use ruint::{uint, Uint, aliases::*};
142    /// # uint!{
143    /// # let modulus = 21888242871839275222246405745257275088548364400416034343698204186575808495617_U256;
144    /// let inv = U64::wrapping_from(modulus).inv_ring().unwrap().wrapping_neg().to();
145    /// let prod = 5_U256.mul_redc(6_U256, modulus, inv);
146    /// # assert_eq!(inv.wrapping_mul(modulus.wrapping_to()), u64::MAX);
147    /// # assert_eq!(inv, 0xc2e1f593efffffff);
148    /// # }
149    /// ```
150    ///
151    /// # Panics
152    ///
153    /// Panics if `inv` is not correct.
154    #[inline]
155    #[must_use]
156    #[cfg(feature = "alloc")] // TODO: Make mul_redc alloc-free
157    pub fn mul_redc(self, other: Self, modulus: Self, inv: u64) -> Self {
158        if BITS == 0 {
159            return Self::ZERO;
160        }
161        assert_eq!(inv.wrapping_mul(modulus.limbs[0]), u64::MAX);
162        let mut result = Self::ZERO;
163        algorithms::mul_redc(
164            self.as_limbs(),
165            other.as_limbs(),
166            &mut result.limbs,
167            modulus.as_limbs(),
168            inv,
169        );
170        debug_assert!(result < modulus);
171        result
172    }
173}
174
// Property tests for the modular arithmetic methods above. Each test
// instantiates `Uint<BITS, LIMBS>` for every `BITS` in a size list via the
// crate's `const_for!` macro and checks an algebraic identity with `proptest!`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{aliases::U64, const_for, nlimbs};
    use core::cmp::min;
    use proptest::{prop_assume, proptest, test_runner::Config};

    /// `mul_mod` is commutative: a⋅b ≡ b⋅a (mod m).
    #[test]
    fn test_commutative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, m: U)| {
                assert_eq!(a.mul_mod(b, m), b.mul_mod(a, m));
            });
        });
    }

    /// `mul_mod` is associative: a⋅(b⋅c) ≡ (a⋅b)⋅c (mod m).
    #[test]
    fn test_associative() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U, m: U)| {
                assert_eq!(a.mul_mod(b.mul_mod(c, m), m), a.mul_mod(b, m).mul_mod(c, m));
            });
        });
    }

    /// `mul_mod` distributes over `add_mod`: a⋅(b+c) ≡ a⋅b + a⋅c (mod m).
    #[test]
    fn test_distributive() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, c: U, m: U)| {
                assert_eq!(a.mul_mod(b.add_mod(c, m), m), a.mul_mod(b, m).add_mod(a.mul_mod(c, m), m));
            });
        });
    }

    /// Adding zero modulo m equals plain reduction modulo m.
    #[test]
    fn test_add_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U, m: U)| {
                assert_eq!(value.add_mod(U::from(0), m), value.reduce_mod(m));
            });
        });
    }

    /// Multiplying by zero yields zero; multiplying by one equals reduction.
    // NOTE(review): presumably restricted to `NON_ZERO` sizes because `U::from(1)`
    // is not representable when `BITS == 0` — confirm.
    #[test]
    fn test_mul_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(value: U, m: U)| {
                assert_eq!(value.mul_mod(U::from(0), m), U::ZERO);
                assert_eq!(value.mul_mod(U::from(1), m), value.reduce_mod(m));
            });
        });
    }

    /// a^0 ≡ 1 (mod m) and a^1 ≡ a (mod m); both sides are reduced, so the
    /// moduli 0 and 1 (where `pow_mod` returns zero) also agree.
    #[test]
    fn test_pow_identity() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, m: U)| {
                assert_eq!(a.pow_mod(U::from(0), m), U::from(1).reduce_mod(m));
                assert_eq!(a.pow_mod(U::from(1), m), a.reduce_mod(m));
            });
        });
    }

    /// Powers distribute over products: (a⋅b)^c ≡ a^c ⋅ b^c (mod m).
    /// The number of cases is capped because `pow_mod` is slow for wide types.
    #[test]
    fn test_pow_rules() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            // TODO: Increase cases when perf is better.
            let mut config = Config::default();
            // BUG: Proptest still runs 5 cases even if we set it to 1.
            config.cases = min(config.cases, if BITS > 500 { 1 } else { 3 });
            proptest!(config, |(a: U, b: U, c: U, m: U)| {
                // TODO: a^(b+c) = a^b * a^c. Which requires carmichael fn.
                // TODO: (a^b)^c = a^(b * c). Which requires carmichael fn.
                assert_eq!(a.mul_mod(b, m).pow_mod(c, m), a.pow_mod(c, m).mul_mod(b.pow_mod(c, m), m));
            });
        });
    }

    /// Whenever `inv_mod` returns an inverse, a⋅inv ≡ 1 (mod m).
    // NOTE(review): if `algorithms::inv_mod` can return `Some(_)` for `m == 1`,
    // `mul_mod` yields 0 there and this assertion against `U::from(1)` would
    // fail — verify the behavior of `inv_mod` for `m == 1`.
    #[test]
    fn test_inv() {
        const_for!(BITS in NON_ZERO {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            // TODO: Increase cases when perf is better.
            let mut config = Config::default();
            config.cases = min(config.cases, if BITS > 500 { 6 } else { 20 });
            proptest!(config, |(a: U, m: U)| {
                if let Some(inv) = a.inv_mod(m) {
                    assert_eq!(a.mul_mod(inv, m), U::from(1));
                }
            });
        });
    }

    /// `mul_redc` on Montgomery-form inputs (a⋅R, b⋅R with R = 2^(64⋅LIMBS))
    /// yields the Montgomery form of the product, (a⋅b)⋅R mod m.
    // NOTE(review): the `BITS >= 16` filter presumably keeps `U::from(64 * LIMBS)`
    // representable — confirm.
    #[test]
    fn test_mul_redc() {
        const_for!(BITS in NON_ZERO if (BITS >= 16) {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            proptest!(|(a: U, b: U, m: U)| {
                prop_assume!(m >= U::from(2));
                // `inv_ring` only succeeds when the low limb of `m` is invertible
                // modulo 2^64; the negated inverse is the `inv` `mul_redc` expects.
                if let Some(inv) = U64::from(m.as_limbs()[0]).inv_ring() {
                    let inv = (-inv).as_limbs()[0];

                    // r = R mod m; ar and br are a and b in Montgomery form.
                    let r = U::from(2).pow_mod(U::from(64 * LIMBS), m);
                    let ar = a.mul_mod(r, m);
                    let br = b.mul_mod(r, m);
                    // TODO: Test for larger (>= m) values of a, b.

                    let expected = a.mul_mod(b, m).mul_mod(r, m);

                    assert_eq!(ar.mul_redc(br, m, inv), expected);
                }
            });
        });
    }
}