1#![allow(clippy::module_name_repetitions)]
2
3use crate::algorithms::{ops::sbb, DoubleWord};
4
5#[inline]
6#[allow(clippy::cast_possible_truncation)] #[allow(dead_code)] pub fn addmul_ref(result: &mut [u64], a: &[u64], b: &[u64]) -> bool {
9 let mut overflow = 0;
10 for (i, a) in a.iter().copied().enumerate() {
11 let mut result = result.iter_mut().skip(i);
12 let mut b = b.iter().copied();
13 let mut carry = 0_u128;
14 loop {
15 match (result.next(), b.next()) {
16 (Some(result), Some(b)) => {
18 carry += u128::from(*result) + u128::from(a) * u128::from(b);
19 *result = carry as u64;
20 carry >>= 64;
21 }
22 (Some(result), None) => {
24 carry += u128::from(*result);
25 *result = carry as u64;
26 carry >>= 64;
27 }
28 (None, Some(b)) => {
30 carry += u128::from(a) * u128::from(b);
31 overflow |= carry as u64;
32 carry >>= 64;
33 }
34 (None, None) => {
36 break;
37 }
38 }
39 }
40 overflow |= carry as u64;
41 }
42 overflow != 0
43}
44
45#[inline]
66pub fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool {
67 while let [0, rest @ ..] = a {
69 a = rest;
70 if let [_, rest @ ..] = lhs {
71 lhs = rest;
72 }
73 }
74 while let [rest @ .., 0] = a {
75 a = rest;
76 }
77
78 while let [0, rest @ ..] = b {
80 b = rest;
81 if let [_, rest @ ..] = lhs {
82 lhs = rest;
83 }
84 }
85 while let [rest @ .., 0] = b {
86 b = rest;
87 }
88
89 if a.is_empty() || b.is_empty() {
90 return false;
91 }
92 if lhs.is_empty() {
93 return true;
94 }
95
96 let (a, b) = if b.len() > a.len() { (b, a) } else { (a, b) };
97
98 let mut overflow = false;
100 for &b in b {
101 if lhs.len() >= a.len() {
102 let (target, rest) = lhs.split_at_mut(a.len());
103 let carry = addmul_nx1(target, a, b);
104 let carry = add_nx1(rest, carry);
105 overflow |= carry != 0;
106 } else {
107 overflow = true;
108 if lhs.is_empty() {
109 break;
110 }
111 addmul_nx1(lhs, &a[..lhs.len()], b);
112 }
113 lhs = &mut lhs[1..];
114 }
115 overflow
116}
117
118#[inline]
120pub fn add_nx1(lhs: &mut [u64], mut a: u64) -> u64 {
121 if a == 0 {
122 return 0;
123 }
124 for lhs in lhs {
125 let sum = u128::add(*lhs, a);
126 *lhs = sum.low();
127 a = sum.high();
128 if a == 0 {
129 return 0;
130 }
131 }
132 a
133}
134
135#[inline(always)]
141pub fn addmul_n(lhs: &mut [u64], a: &[u64], b: &[u64]) {
142 assert_eq!(lhs.len(), a.len());
143 assert_eq!(lhs.len(), b.len());
144 match lhs.len() {
145 0 => {}
146 1 => addmul_1(lhs, a, b),
147 2 => addmul_2(lhs, a, b),
148 3 => addmul_3(lhs, a, b),
149 4 => addmul_4(lhs, a, b),
150 _ => {
151 let _ = addmul(lhs, a, b);
152 }
153 }
154}
155
156#[inline(always)]
158fn addmul_1(lhs: &mut [u64], a: &[u64], b: &[u64]) {
159 assert_eq!(lhs.len(), 1);
160 assert_eq!(a.len(), 1);
161 assert_eq!(b.len(), 1);
162
163 mac(&mut lhs[0], a[0], b[0], 0);
164}
165
166#[inline(always)]
168fn addmul_2(lhs: &mut [u64], a: &[u64], b: &[u64]) {
169 assert_eq!(lhs.len(), 2);
170 assert_eq!(a.len(), 2);
171 assert_eq!(b.len(), 2);
172
173 let carry = mac(&mut lhs[0], a[0], b[0], 0);
174 mac(&mut lhs[1], a[0], b[1], carry);
175
176 mac(&mut lhs[1], a[1], b[0], 0);
177}
178
179#[inline(always)]
181fn addmul_3(lhs: &mut [u64], a: &[u64], b: &[u64]) {
182 assert_eq!(lhs.len(), 3);
183 assert_eq!(a.len(), 3);
184 assert_eq!(b.len(), 3);
185
186 let carry = mac(&mut lhs[0], a[0], b[0], 0);
187 let carry = mac(&mut lhs[1], a[0], b[1], carry);
188 mac(&mut lhs[2], a[0], b[2], carry);
189
190 let carry = mac(&mut lhs[1], a[1], b[0], 0);
191 mac(&mut lhs[2], a[1], b[1], carry);
192
193 mac(&mut lhs[2], a[2], b[0], 0);
194}
195
196#[inline(always)]
198fn addmul_4(lhs: &mut [u64], a: &[u64], b: &[u64]) {
199 assert_eq!(lhs.len(), 4);
200 assert_eq!(a.len(), 4);
201 assert_eq!(b.len(), 4);
202
203 let carry = mac(&mut lhs[0], a[0], b[0], 0);
204 let carry = mac(&mut lhs[1], a[0], b[1], carry);
205 let carry = mac(&mut lhs[2], a[0], b[2], carry);
206 mac(&mut lhs[3], a[0], b[3], carry);
207
208 let carry = mac(&mut lhs[1], a[1], b[0], 0);
209 let carry = mac(&mut lhs[2], a[1], b[1], carry);
210 mac(&mut lhs[3], a[1], b[2], carry);
211
212 let carry = mac(&mut lhs[2], a[2], b[0], 0);
213 mac(&mut lhs[3], a[2], b[1], carry);
214
215 mac(&mut lhs[3], a[3], b[0], 0);
216}
217
218#[inline(always)]
219fn mac(lhs: &mut u64, a: u64, b: u64, c: u64) -> u64 {
220 let prod = u128::muladd2(a, b, c, *lhs);
221 *lhs = prod.low();
222 prod.high()
223}
224
225#[inline]
227pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 {
228 let mut carry = 0;
229 for lhs in &mut *lhs {
230 let product = u128::muladd(*lhs, a, carry);
231 *lhs = product.low();
232 carry = product.high();
233 }
234 carry
235}
236
237#[inline]
248pub fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
249 debug_assert_eq!(lhs.len(), a.len());
250 let mut carry = 0;
251 for (lhs, a) in lhs.iter_mut().zip(a.iter().copied()) {
252 let product = u128::muladd2(a, b, carry, *lhs);
253 *lhs = product.low();
254 carry = product.high();
255 }
256 carry
257}
258
259#[inline]
271pub fn submul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
272 debug_assert_eq!(lhs.len(), a.len());
273 let mut carry = 0;
274 let mut borrow = 0;
275 for (lhs, a) in lhs.iter_mut().zip(a.iter().copied()) {
276 let limb = {
278 let product = u128::muladd(a, b, carry);
279 carry = product.high();
280 product.low()
281 };
282
283 let (new, b) = sbb(*lhs, limb, borrow);
285 *lhs = new;
286 borrow = b;
287 }
288 borrow + carry
289}
290
// Unit and property tests for the multiply-accumulate routines above.
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::{collection, num::u64, proptest};

    /// Property test: `addmul` must leave the same limbs in `lhs` and report
    /// the same overflow flag as the reference `addmul_ref`, for slices of
    /// zero to nine arbitrary limbs each.
    #[test]
    fn test_addmul() {
        let any_vec = collection::vec(u64::ANY, 0..10);
        proptest!(|(mut lhs in &any_vec, a in &any_vec, b in &any_vec)| {
            // Compute the expected outcome on a copy of the accumulator.
            let mut ref_lhs = lhs.clone();
            let ref_overflow = addmul_ref(&mut ref_lhs, &a, &b);

            // Run the implementation under test and compare both outputs.
            let overflow = addmul(&mut lhs, &a, &b);
            assert_eq!(lhs, ref_lhs);
            assert_eq!(overflow, ref_overflow);
        });
    }

    /// Runs `addmul` on a zero-filled accumulator with as many limbs as
    /// `expected`, then checks the overflow flag and the resulting limbs.
    fn test_vals(lhs: &[u64], rhs: &[u64], expected: &[u64], expected_overflow: bool) {
        let mut result = vec![0; expected.len()];
        let overflow = addmul(&mut result, lhs, rhs);
        assert_eq!(overflow, expected_overflow);
        assert_eq!(result, expected);
    }

    /// Edge cases with empty factors and with an empty or one-limb accumulator.
    #[test]
    fn test_empty() {
        // Empty accumulator: only a non-zero product overflows.
        test_vals(&[], &[], &[], false);
        test_vals(&[], &[1], &[], false);
        test_vals(&[1], &[], &[], false);
        test_vals(&[1], &[1], &[], true);
        // One-limb accumulator: an empty factor leaves it at zero.
        test_vals(&[], &[], &[0], false);
        test_vals(&[], &[1], &[0], false);
        test_vals(&[1], &[], &[0], false);
        test_vals(&[1], &[1], &[1], false);
    }

    /// Fixed four-limb input/output vector for `submul_nx1`.
    // NOTE(review): the origin of these constants is not visible here.
    #[test]
    fn test_submul_nx1() {
        let mut lhs = [
            15520854688669198950,
            13760048731709406392,
            14363314282014368551,
            13263184899940581802,
        ];
        let a = [
            7955980792890017645,
            6297379555503105007,
            2473663400150304794,
            18362433840513668572,
        ];
        let b = 17275533833223164845;
        let borrow = submul_nx1(&mut lhs, &a, b);
        assert_eq!(lhs, [
            2427453526388035261,
            7389014268281543265,
            6670181329660292018,
            8411211985208067428
        ]);
        assert_eq!(borrow, 17196576577663999042);
    }
}