// ruint/algorithms/gcd/mod.rs

1#![allow(clippy::module_name_repetitions)]
2
3// TODO: Make these algorithms work on limb slices.
4mod matrix;
5
6pub use self::matrix::Matrix as LehmerMatrix;
7use crate::Uint;
8use core::mem::swap;
9
10/// ⚠️ Lehmer's GCD algorithms.
11///
12/// **Warning.** This struct is not part of the stable API.
13///
14/// See [`gcd_extended`] for documentation.
15#[inline]
16#[must_use]
17pub fn gcd<const BITS: usize, const LIMBS: usize>(
18    mut a: Uint<BITS, LIMBS>,
19    mut b: Uint<BITS, LIMBS>,
20) -> Uint<BITS, LIMBS> {
21    if b > a {
22        swap(&mut a, &mut b);
23    }
24    while b != Uint::ZERO {
25        debug_assert!(a >= b);
26        let m = LehmerMatrix::from(a, b);
27        if m == LehmerMatrix::IDENTITY {
28            // Lehmer step failed to find a factor, which happens when
29            // the factor is very large. We do a regular Euclidean step, which
30            // will make a lot of progress since `q` will be large.
31            a %= b;
32            swap(&mut a, &mut b);
33        } else {
34            m.apply(&mut a, &mut b);
35        }
36    }
37    a
38}
39
40/// ⚠️ Lehmer's extended GCD.
41///
42/// **Warning.** This struct is not part of the stable API.
43///
44/// Returns `(gcd, x, y, sign)` such that `gcd = a * x + b * y`.
45///
46/// # Algorithm
47///
48/// A variation of Euclids algorithm where repeated 64-bit approximations are
49/// used to make rapid progress on.
50///
51/// See Jebelean (1994) "A Double-Digit Lehmer-Euclid Algorithm for Finding the
52/// GCD of Long Integers".
53///
54/// The function `lehmer_double` takes two `U256`'s and returns a 64-bit matrix.
55///
56/// The function `lehmer_update` updates state variables using this matrix. If
57/// the matrix makes no progress (because 64 bit precision is not enough) a full
58/// precision Euclid step is done, but this happens rarely.
59///
60/// See also `mpn_gcdext_lehmer_n` in GMP.
61/// <https://gmplib.org/repo/gmp-6.1/file/tip/mpn/generic/gcdext_lehmer.c#l146>
62#[inline]
63#[must_use]
64pub fn gcd_extended<const BITS: usize, const LIMBS: usize>(
65    mut a: Uint<BITS, LIMBS>,
66    mut b: Uint<BITS, LIMBS>,
67) -> (
68    Uint<BITS, LIMBS>,
69    Uint<BITS, LIMBS>,
70    Uint<BITS, LIMBS>,
71    bool,
72) {
73    if BITS == 0 {
74        return (Uint::ZERO, Uint::ZERO, Uint::ZERO, false);
75    }
76    let swapped = a < b;
77    if swapped {
78        swap(&mut a, &mut b);
79    }
80
81    // Initialize state matrix to identity.
82    let mut s0 = Uint::from(1);
83    let mut s1 = Uint::ZERO;
84    let mut t0 = Uint::ZERO;
85    let mut t1 = Uint::from(1);
86    let mut even = true;
87    while b != Uint::ZERO {
88        debug_assert!(a >= b);
89        let m = LehmerMatrix::from(a, b);
90        if m == LehmerMatrix::IDENTITY {
91            // Lehmer step failed to find a factor, which happens when
92            // the factor is very large. We do a regular Euclidean step, which
93            // will make a lot of progress since `q` will be large.
94            let q = a / b;
95            a -= q * b;
96            swap(&mut a, &mut b);
97            s0 -= q * s1;
98            swap(&mut s0, &mut s1);
99            t0 -= q * t1;
100            swap(&mut t0, &mut t1);
101            even = !even;
102        } else {
103            m.apply(&mut a, &mut b);
104            m.apply(&mut s0, &mut s1);
105            m.apply(&mut t0, &mut t1);
106            even ^= !m.4;
107        }
108    }
109    // TODO: Compute using absolute value instead of patching sign.
110    if even {
111        // t negative
112        t0 = Uint::ZERO - t0;
113    } else {
114        // s negative
115        s0 = Uint::ZERO - s0;
116    }
117    if swapped {
118        swap(&mut s0, &mut t0);
119        even = !even;
120    }
121    (a, s0, t0, even)
122}
123
124/// ⚠️ Modular inversion using extended GCD.
125///
126/// It uses the Bezout identity
127///
128/// ```text
129///    a * modulus + b * num = gcd(modulus, num)
130/// ````
131///
132/// where `a` and `b` are the cofactors from the extended Euclidean algorithm.
133/// A modular inverse only exists if `modulus` and `num` are coprime, in which
134/// case `gcd(modulus, num)` is one. Reducing both sides by the modulus then
135/// results in the equation `b * num = 1 (mod modulus)`. In other words, the
136/// cofactor `b` is the modular inverse of `num`.
137///
138/// It differs from `gcd_extended` in that it only computes the required
139/// cofactor, and returns `None` if the GCD is not one (i.e. when `num` does
140/// not have an inverse).
141#[inline]
142#[must_use]
143pub fn inv_mod<const BITS: usize, const LIMBS: usize>(
144    num: Uint<BITS, LIMBS>,
145    modulus: Uint<BITS, LIMBS>,
146) -> Option<Uint<BITS, LIMBS>> {
147    if BITS == 0 || modulus == Uint::ZERO {
148        return None;
149    }
150    let mut a = modulus;
151    let mut b = num;
152    if b >= a {
153        b %= a;
154    }
155    if b == Uint::ZERO {
156        return None;
157    }
158
159    let mut t0 = Uint::ZERO;
160    let mut t1 = Uint::from(1);
161    let mut even = true;
162    while b != Uint::ZERO {
163        debug_assert!(a >= b);
164        let m = LehmerMatrix::from(a, b);
165        if m == LehmerMatrix::IDENTITY {
166            // Lehmer step failed to find a factor, which happens when
167            // the factor is very large. We do a regular Euclidean step, which
168            // will make a lot of progress since `q` will be large.
169            let q = a / b;
170            a -= q * b;
171            swap(&mut a, &mut b);
172            t0 -= q * t1;
173            swap(&mut t0, &mut t1);
174            even = !even;
175        } else {
176            m.apply(&mut a, &mut b);
177            m.apply(&mut t0, &mut t1);
178            even ^= !m.4;
179        }
180    }
181    if a == Uint::from(1) {
182        // When `even` t0 is negative and in twos-complement form
183        Some(if even { modulus + t0 } else { t0 })
184    } else {
185        None
186    }
187}
188
#[cfg(test)]
#[allow(clippy::cast_lossless)]
mod tests {
    use super::*;
    use crate::{const_for, nlimbs};
    use core::cmp::min;
    use proptest::{proptest, test_runner::Config};

    /// Regression test on a fixed pair of 129-bit inputs.
    #[test]
    fn test_gcd_one() {
        use core::str::FromStr;
        const BITS: usize = 129;
        const LIMBS: usize = nlimbs(BITS);
        type U = Uint<BITS, LIMBS>;
        let a = U::from_str("0x006d7c4641f88b729a97889164dd8d07db").unwrap();
        let b = U::from_str("0x01de6ef6f3caa963a548d7a411b05b9988").unwrap();
        assert_eq!(gcd(a, b), gcd_ref(a, b));
    }

    /// Reference implementation: the textbook Euclidean algorithm.
    fn gcd_ref<const BITS: usize, const LIMBS: usize>(
        mut a: Uint<BITS, LIMBS>,
        mut b: Uint<BITS, LIMBS>,
    ) -> Uint<BITS, LIMBS> {
        while b != Uint::ZERO {
            a %= b;
            swap(&mut a, &mut b);
        }
        a
    }

    #[test]
    #[allow(clippy::absurd_extreme_comparisons)] // Generated code
    fn test_gcd() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            // TODO: Increase cases when perf is better.
            let mut config = Config::default();
            config.cases = min(config.cases, if BITS > 500 { 9 } else { 30 });
            proptest!(config, |(a: U, b: U)| {
                assert_eq!(gcd(a, b), gcd_ref(a, b));
            });
        });
    }

    #[test]
    #[allow(clippy::absurd_extreme_comparisons)] // Generated code
    fn test_gcd_extended() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            // TODO: Increase cases when perf is better.
            let mut config = Config::default();
            config.cases = min(config.cases, if BITS > 500 { 3 } else { 10 });
            proptest!(config, |(a: U, b: U)| {
                let (g, x, y, sign) = gcd_extended(a, b);
                assert_eq!(g, gcd_ref(a, b));
                if sign {
                    assert_eq!(a * x - b * y, g);
                } else {
                    assert_eq!(b * y - a * x, g);
                }
            });
        });
    }

    /// Checks `inv_mod` against its contract: a returned value is a reduced
    /// inverse, and `None` is only returned for a zero or unit modulus or when
    /// `num` and `modulus` are not coprime.
    #[test]
    #[allow(clippy::absurd_extreme_comparisons)] // Generated code
    fn test_inv_mod() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            // TODO: Increase cases when perf is better.
            let mut config = Config::default();
            config.cases = min(config.cases, if BITS > 500 { 3 } else { 10 });
            proptest!(config, |(num: U, modulus: U)| {
                match inv_mod(num, modulus) {
                    Some(inverse) => {
                        // Only reachable with `modulus >= 2`, so `U::from(1)`
                        // is representable here.
                        assert!(inverse < modulus);
                        assert_eq!(num.mul_mod(inverse, modulus), U::from(1));
                    }
                    None => {
                        // `modulus == 0` is tested first so that `U::from(1)`
                        // is never evaluated when `BITS == 0`.
                        assert!(
                            modulus == U::ZERO
                                || modulus == U::from(1)
                                || gcd_ref(num, modulus) != U::from(1)
                        );
                    }
                }
            });
        });
    }
}