ruint/algorithms/gcd/
mod.rs1#![allow(clippy::module_name_repetitions)]
2
3mod matrix;
5
6pub use self::matrix::Matrix as LehmerMatrix;
7use crate::Uint;
8use core::mem::swap;
9
10#[inline]
16#[must_use]
17pub fn gcd<const BITS: usize, const LIMBS: usize>(
18 mut a: Uint<BITS, LIMBS>,
19 mut b: Uint<BITS, LIMBS>,
20) -> Uint<BITS, LIMBS> {
21 if b > a {
22 swap(&mut a, &mut b);
23 }
24 while b != Uint::ZERO {
25 debug_assert!(a >= b);
26 let m = LehmerMatrix::from(a, b);
27 if m == LehmerMatrix::IDENTITY {
28 a %= b;
32 swap(&mut a, &mut b);
33 } else {
34 m.apply(&mut a, &mut b);
35 }
36 }
37 a
38}
39
40#[inline]
63#[must_use]
64pub fn gcd_extended<const BITS: usize, const LIMBS: usize>(
65 mut a: Uint<BITS, LIMBS>,
66 mut b: Uint<BITS, LIMBS>,
67) -> (
68 Uint<BITS, LIMBS>,
69 Uint<BITS, LIMBS>,
70 Uint<BITS, LIMBS>,
71 bool,
72) {
73 if BITS == 0 {
74 return (Uint::ZERO, Uint::ZERO, Uint::ZERO, false);
75 }
76 let swapped = a < b;
77 if swapped {
78 swap(&mut a, &mut b);
79 }
80
81 let mut s0 = Uint::from(1);
83 let mut s1 = Uint::ZERO;
84 let mut t0 = Uint::ZERO;
85 let mut t1 = Uint::from(1);
86 let mut even = true;
87 while b != Uint::ZERO {
88 debug_assert!(a >= b);
89 let m = LehmerMatrix::from(a, b);
90 if m == LehmerMatrix::IDENTITY {
91 let q = a / b;
95 a -= q * b;
96 swap(&mut a, &mut b);
97 s0 -= q * s1;
98 swap(&mut s0, &mut s1);
99 t0 -= q * t1;
100 swap(&mut t0, &mut t1);
101 even = !even;
102 } else {
103 m.apply(&mut a, &mut b);
104 m.apply(&mut s0, &mut s1);
105 m.apply(&mut t0, &mut t1);
106 even ^= !m.4;
107 }
108 }
109 if even {
111 t0 = Uint::ZERO - t0;
113 } else {
114 s0 = Uint::ZERO - s0;
116 }
117 if swapped {
118 swap(&mut s0, &mut t0);
119 even = !even;
120 }
121 (a, s0, t0, even)
122}
123
124#[inline]
142#[must_use]
143pub fn inv_mod<const BITS: usize, const LIMBS: usize>(
144 num: Uint<BITS, LIMBS>,
145 modulus: Uint<BITS, LIMBS>,
146) -> Option<Uint<BITS, LIMBS>> {
147 if BITS == 0 || modulus == Uint::ZERO {
148 return None;
149 }
150 let mut a = modulus;
151 let mut b = num;
152 if b >= a {
153 b %= a;
154 }
155 if b == Uint::ZERO {
156 return None;
157 }
158
159 let mut t0 = Uint::ZERO;
160 let mut t1 = Uint::from(1);
161 let mut even = true;
162 while b != Uint::ZERO {
163 debug_assert!(a >= b);
164 let m = LehmerMatrix::from(a, b);
165 if m == LehmerMatrix::IDENTITY {
166 let q = a / b;
170 a -= q * b;
171 swap(&mut a, &mut b);
172 t0 -= q * t1;
173 swap(&mut t0, &mut t1);
174 even = !even;
175 } else {
176 m.apply(&mut a, &mut b);
177 m.apply(&mut t0, &mut t1);
178 even ^= !m.4;
179 }
180 }
181 if a == Uint::from(1) {
182 Some(if even { modulus + t0 } else { t0 })
184 } else {
185 None
186 }
187}
188
#[cfg(test)]
#[allow(clippy::cast_lossless)]
mod tests {
    use super::*;
    use crate::{const_for, nlimbs};
    use core::cmp::min;
    use proptest::{proptest, test_runner::Config};

    /// Fixed regression case: two 129-bit inputs checked against `gcd_ref`.
    #[test]
    fn test_gcd_one() {
        use core::str::FromStr;
        const BITS: usize = 129;
        const LIMBS: usize = nlimbs(BITS);
        type U = Uint<BITS, LIMBS>;
        let a = U::from_str("0x006d7c4641f88b729a97889164dd8d07db").unwrap();
        let b = U::from_str("0x01de6ef6f3caa963a548d7a411b05b9988").unwrap();
        assert_eq!(gcd(a, b), gcd_ref(a, b));
    }

    /// Reference gcd: the plain Euclidean algorithm, without Lehmer steps.
    fn gcd_ref<const BITS: usize, const LIMBS: usize>(
        mut a: Uint<BITS, LIMBS>,
        mut b: Uint<BITS, LIMBS>,
    ) -> Uint<BITS, LIMBS> {
        while b != Uint::ZERO {
            a %= b;
            swap(&mut a, &mut b);
        }
        a
    }

    #[test]
    #[allow(clippy::absurd_extreme_comparisons)]
    fn test_gcd() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            let mut config = Config::default();
            // Large widths are slow; keep the number of cases small there.
            config.cases = min(config.cases, if BITS > 500 { 9 } else { 30 });
            proptest!(config, |(a: U, b: U)| {
                assert_eq!(gcd(a, b), gcd_ref(a, b));
            });
        });
    }

    #[test]
    #[allow(clippy::absurd_extreme_comparisons)]
    fn test_gcd_extended() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            let mut config = Config::default();
            config.cases = min(config.cases, if BITS > 500 { 3 } else { 10 });
            proptest!(config, |(a: U, b: U)| {
                let (g, x, y, sign) = gcd_extended(a, b);
                assert_eq!(g, gcd_ref(a, b));
                // `sign` selects which Bezout form holds (wrapping arithmetic).
                if sign {
                    assert_eq!(a * x - b * y, g);
                } else {
                    assert_eq!(b * y - a * x, g);
                }
            });
        });
    }

    #[test]
    #[allow(clippy::absurd_extreme_comparisons)]
    fn test_inv_mod() {
        const_for!(BITS in SIZES {
            const LIMBS: usize = nlimbs(BITS);
            type U = Uint<BITS, LIMBS>;
            let mut config = Config::default();
            config.cases = min(config.cases, if BITS > 500 { 3 } else { 10 });
            proptest!(config, |(num: U, modulus: U)| {
                match inv_mod(num, modulus) {
                    Some(inv) => {
                        // A returned inverse is reduced and actually inverts `num`.
                        assert!(inv < modulus);
                        assert_eq!(num.mul_mod(inv, modulus), U::from(1));
                    }
                    // `U::from(1)` does not fit in zero bits; `inv_mod` always
                    // returns `None` for `BITS == 0`, so there is nothing to check.
                    None if BITS > 0 => {
                        // `None` must mean that no inverse exists.
                        let one = U::from(1);
                        assert!(modulus <= one || gcd(num, modulus) != one);
                    }
                    None => {}
                }
            });
        });
    }
}