ruint/algorithms/
mul.rs

1#![allow(clippy::module_name_repetitions)]
2
3use crate::algorithms::{ops::sbb, DoubleWord};
4
/// Reference (slow, obviously-correct) implementation of `result += a * b`
/// used to cross-check the optimized routines in tests.
///
/// All slices are little-endian limbs of arbitrary length. Product limbs that
/// land beyond `result.len()` are dropped; the return value is `true` iff any
/// dropped limb (or final carry) was non-zero.
#[inline]
#[allow(clippy::cast_possible_truncation)] // Intentional truncation.
#[allow(dead_code)] // Used for testing
pub fn addmul_ref(result: &mut [u64], a: &[u64], b: &[u64]) -> bool {
    // Bitwise OR of every limb that fell off the top of `result`.
    let mut spilled = 0_u64;
    for (shift, &x) in a.iter().enumerate() {
        // Add the row `x * b`, shifted up by `shift` limbs, into `result`.
        let mut acc = 0_u128;
        for j in 0.. {
            match (result.get_mut(shift + j), b.get(j)) {
                // A result limb and a `b` limb: accumulate the partial product.
                (Some(slot), Some(&y)) => {
                    acc += u128::from(*slot) + u128::from(x) * u128::from(y);
                    *slot = acc as u64;
                }
                // `b` is exhausted: keep rippling the carry upward.
                (Some(slot), None) => {
                    acc += u128::from(*slot);
                    *slot = acc as u64;
                }
                // `result` is exhausted: these product limbs do not fit.
                (None, Some(&y)) => {
                    acc += u128::from(x) * u128::from(y);
                    spilled |= acc as u64;
                }
                (None, None) => break,
            }
            acc >>= 64;
        }
        // Any carry left over lies past the end of `result` as well.
        spilled |= acc as u64;
    }
    spilled != 0
}
44
45/// ⚠️ Computes `result += a * b` and checks for overflow.
46///
47/// **Warning.** This function is not part of the stable API.
48///
49/// Arrays are in little-endian order. All arrays can be arbitrary sized.
50///
51/// # Algorithm
52///
53/// Trims zeros from inputs, then uses the schoolbook multiplication algorithm.
54/// It takes the shortest input as the outer loop.
55///
56/// # Examples
57///
58/// ```
59/// # use ruint::algorithms::addmul;
60/// let mut result = [0];
61/// let overflow = addmul(&mut result, &[3], &[4]);
62/// assert_eq!(overflow, false);
63/// assert_eq!(result, [12]);
64/// ```
65#[inline]
66pub fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool {
67    // Trim zeros from `a`
68    while let [0, rest @ ..] = a {
69        a = rest;
70        if let [_, rest @ ..] = lhs {
71            lhs = rest;
72        }
73    }
74    while let [rest @ .., 0] = a {
75        a = rest;
76    }
77
78    // Trim zeros from `b`
79    while let [0, rest @ ..] = b {
80        b = rest;
81        if let [_, rest @ ..] = lhs {
82            lhs = rest;
83        }
84    }
85    while let [rest @ .., 0] = b {
86        b = rest;
87    }
88
89    if a.is_empty() || b.is_empty() {
90        return false;
91    }
92    if lhs.is_empty() {
93        return true;
94    }
95
96    let (a, b) = if b.len() > a.len() { (b, a) } else { (a, b) };
97
98    // Iterate over limbs of `b` and add partial products to `lhs`.
99    let mut overflow = false;
100    for &b in b {
101        if lhs.len() >= a.len() {
102            let (target, rest) = lhs.split_at_mut(a.len());
103            let carry = addmul_nx1(target, a, b);
104            let carry = add_nx1(rest, carry);
105            overflow |= carry != 0;
106        } else {
107            overflow = true;
108            if lhs.is_empty() {
109                break;
110            }
111            addmul_nx1(lhs, &a[..lhs.len()], b);
112        }
113        lhs = &mut lhs[1..];
114    }
115    overflow
116}
117
/// Adds the single limb `a` into the little-endian number `lhs` in place.
///
/// Returns the carry out of the most significant limb: `0` or `1`, or `a`
/// itself when `lhs` is empty.
#[inline]
pub fn add_nx1(lhs: &mut [u64], a: u64) -> u64 {
    let mut carry = a;
    for limb in lhs.iter_mut() {
        // Once the carry has died out, the higher limbs are unaffected.
        if carry == 0 {
            break;
        }
        let (sum, wrapped) = limb.overflowing_add(carry);
        *limb = sum;
        carry = u64::from(wrapped);
    }
    carry
}
134
135/// Computes wrapping `lhs += a * b` when all arguments are the same length.
136///
137/// # Panics
138///
139/// Panics if the lengts are not the same.
140#[inline(always)]
141pub fn addmul_n(lhs: &mut [u64], a: &[u64], b: &[u64]) {
142    assert_eq!(lhs.len(), a.len());
143    assert_eq!(lhs.len(), b.len());
144    match lhs.len() {
145        0 => {}
146        1 => addmul_1(lhs, a, b),
147        2 => addmul_2(lhs, a, b),
148        3 => addmul_3(lhs, a, b),
149        4 => addmul_4(lhs, a, b),
150        _ => {
151            let _ = addmul(lhs, a, b);
152        }
153    }
154}
155
156/// Computes `lhs += a * b` for 1 limb.
157#[inline(always)]
158fn addmul_1(lhs: &mut [u64], a: &[u64], b: &[u64]) {
159    assert_eq!(lhs.len(), 1);
160    assert_eq!(a.len(), 1);
161    assert_eq!(b.len(), 1);
162
163    mac(&mut lhs[0], a[0], b[0], 0);
164}
165
/// Computes wrapping `lhs += a * b` for 2 limbs.
///
/// Unrolled schoolbook product truncated to 2 limbs: row `i` adds `a[i] * b`
/// shifted up by `i` limbs, keeping only the terms `a[i] * b[j]` with
/// `i + j < 2`. The carry out of `lhs[1]` is discarded.
///
/// # Panics
///
/// Panics unless `lhs`, `a` and `b` all have exactly 2 limbs.
#[inline(always)]
fn addmul_2(lhs: &mut [u64], a: &[u64], b: &[u64]) {
    assert_eq!(lhs.len(), 2);
    assert_eq!(a.len(), 2);
    assert_eq!(b.len(), 2);

    // Row 0: lhs[0..2] += a[0] * b[0..2]; the final carry is dropped.
    let carry = mac(&mut lhs[0], a[0], b[0], 0);
    mac(&mut lhs[1], a[0], b[1], carry);

    // Row 1: lhs[1] += a[1] * b[0]; only the low limb fits.
    mac(&mut lhs[1], a[1], b[0], 0);
}
178
/// Computes wrapping `lhs += a * b` for 3 limbs.
///
/// Unrolled schoolbook product truncated to 3 limbs: row `i` adds `a[i] * b`
/// shifted up by `i` limbs, keeping only the terms `a[i] * b[j]` with
/// `i + j < 3`. Carries out of `lhs[2]` are discarded.
///
/// # Panics
///
/// Panics unless `lhs`, `a` and `b` all have exactly 3 limbs.
#[inline(always)]
fn addmul_3(lhs: &mut [u64], a: &[u64], b: &[u64]) {
    assert_eq!(lhs.len(), 3);
    assert_eq!(a.len(), 3);
    assert_eq!(b.len(), 3);

    // Row 0: lhs[0..3] += a[0] * b[0..3].
    let carry = mac(&mut lhs[0], a[0], b[0], 0);
    let carry = mac(&mut lhs[1], a[0], b[1], carry);
    mac(&mut lhs[2], a[0], b[2], carry);

    // Row 1: lhs[1..3] += a[1] * b[0..2].
    let carry = mac(&mut lhs[1], a[1], b[0], 0);
    mac(&mut lhs[2], a[1], b[1], carry);

    // Row 2: lhs[2] += a[2] * b[0].
    mac(&mut lhs[2], a[2], b[0], 0);
}
195
/// Computes wrapping `lhs += a * b` for 4 limbs.
///
/// Unrolled schoolbook product truncated to 4 limbs: row `i` adds `a[i] * b`
/// shifted up by `i` limbs, keeping only the terms `a[i] * b[j]` with
/// `i + j < 4`. Carries out of `lhs[3]` are discarded.
///
/// # Panics
///
/// Panics unless `lhs`, `a` and `b` all have exactly 4 limbs.
#[inline(always)]
fn addmul_4(lhs: &mut [u64], a: &[u64], b: &[u64]) {
    assert_eq!(lhs.len(), 4);
    assert_eq!(a.len(), 4);
    assert_eq!(b.len(), 4);

    // Row 0: lhs[0..4] += a[0] * b[0..4].
    let carry = mac(&mut lhs[0], a[0], b[0], 0);
    let carry = mac(&mut lhs[1], a[0], b[1], carry);
    let carry = mac(&mut lhs[2], a[0], b[2], carry);
    mac(&mut lhs[3], a[0], b[3], carry);

    // Row 1: lhs[1..4] += a[1] * b[0..3].
    let carry = mac(&mut lhs[1], a[1], b[0], 0);
    let carry = mac(&mut lhs[2], a[1], b[1], carry);
    mac(&mut lhs[3], a[1], b[2], carry);

    // Row 2: lhs[2..4] += a[2] * b[0..2].
    let carry = mac(&mut lhs[2], a[2], b[0], 0);
    mac(&mut lhs[3], a[2], b[1], carry);

    // Row 3: lhs[3] += a[3] * b[0].
    mac(&mut lhs[3], a[3], b[0], 0);
}
217
218#[inline(always)]
219fn mac(lhs: &mut u64, a: u64, b: u64, c: u64) -> u64 {
220    let prod = u128::muladd2(a, b, c, *lhs);
221    *lhs = prod.low();
222    prod.high()
223}
224
225/// Computes `lhs *= a` and returns the carry.
226#[inline]
227pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 {
228    let mut carry = 0;
229    for lhs in &mut *lhs {
230        let product = u128::muladd(*lhs, a, carry);
231        *lhs = product.low();
232        carry = product.high();
233    }
234    carry
235}
236
237/// Computes `lhs += a * b` and returns the carry.
238///
239/// Requires `lhs.len() == a.len()`.
240///
241/// $$
242/// \begin{aligned}
243/// \mathsf{lhs'} &= \mod{\mathsf{lhs} + \mathsf{a} ⋅ \mathsf{b}}_{2^{64⋅N}}
244/// \\\\ \mathsf{carry} &= \floor{\frac{\mathsf{lhs} + \mathsf{a} ⋅ \mathsf{b}
245/// }{2^{64⋅N}}} \end{aligned}
246/// $$
247#[inline]
248pub fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
249    debug_assert_eq!(lhs.len(), a.len());
250    let mut carry = 0;
251    for (lhs, a) in lhs.iter_mut().zip(a.iter().copied()) {
252        let product = u128::muladd2(a, b, carry, *lhs);
253        *lhs = product.low();
254        carry = product.high();
255    }
256    carry
257}
258
259/// Computes `lhs -= a * b` and returns the borrow.
260///
261/// Requires `lhs.len() == a.len()`.
262///
263/// $$
264/// \begin{aligned}
265/// \mathsf{lhs'} &= \mod{\mathsf{lhs} - \mathsf{a} ⋅ \mathsf{b}}_{2^{64⋅N}}
266/// \\\\ \mathsf{borrow} &= \floor{\frac{\mathsf{a} ⋅ \mathsf{b} -
267/// \mathsf{lhs}}{2^{64⋅N}}} \end{aligned}
268/// $$
269// OPT: `carry` and `borrow` can probably be merged into a single var.
270#[inline]
271pub fn submul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
272    debug_assert_eq!(lhs.len(), a.len());
273    let mut carry = 0;
274    let mut borrow = 0;
275    for (lhs, a) in lhs.iter_mut().zip(a.iter().copied()) {
276        // Compute product limbs
277        let limb = {
278            let product = u128::muladd(a, b, carry);
279            carry = product.high();
280            product.low()
281        };
282
283        // Subtract
284        let (new, b) = sbb(*lhs, limb, borrow);
285        *lhs = new;
286        borrow = b;
287    }
288    borrow + carry
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::{collection, num::u64, proptest};

    #[test]
    fn test_addmul() {
        let any_vec = collection::vec(u64::ANY, 0..10);
        proptest!(|(mut lhs in &any_vec, a in &any_vec, b in &any_vec)| {
            // Reference
            let mut ref_lhs = lhs.clone();
            let ref_overflow = addmul_ref(&mut ref_lhs, &a, &b);

            // Test
            let overflow = addmul(&mut lhs, &a, &b);
            assert_eq!(lhs, ref_lhs);
            assert_eq!(overflow, ref_overflow);
        });
    }

    #[test]
    fn test_addmul_n() {
        // Covers the hand-unrolled kernels (lengths 1..=4) as well as the
        // generic fallback, checked against the reference implementation.
        let any_vec = collection::vec(u64::ANY, 0..10);
        proptest!(|(mut lhs in &any_vec, a in &any_vec, b in &any_vec)| {
            // `addmul_n` requires equal lengths; cut all three to the shortest.
            let n = lhs.len().min(a.len()).min(b.len());
            lhs.truncate(n);
            let (a, b) = (&a[..n], &b[..n]);

            // `addmul_ref` leaves the truncated (wrapping) sum in `ref_lhs`.
            let mut ref_lhs = lhs.clone();
            let _ = addmul_ref(&mut ref_lhs, a, b);

            addmul_n(&mut lhs, a, b);
            assert_eq!(lhs, ref_lhs);
        });
    }

    #[test]
    fn test_add_nx1() {
        let any_vec = collection::vec(u64::ANY, 0..10);
        proptest!(|(mut lhs in &any_vec, a in u64::ANY)| {
            // One spare limb holds the carry, so the reference cannot overflow.
            let mut expected = lhs.clone();
            expected.push(0);
            assert!(!addmul_ref(&mut expected, &[a], &[1]));

            let carry = add_nx1(&mut lhs, a);
            lhs.push(carry);
            assert_eq!(lhs, expected);
        });
    }

    #[test]
    fn test_mul_nx1() {
        let any_vec = collection::vec(u64::ANY, 0..10);
        proptest!(|(mut lhs in &any_vec, a in u64::ANY)| {
            // One extra limb holds the carry, so the reference cannot overflow.
            let mut expected = vec![0; lhs.len() + 1];
            assert!(!addmul_ref(&mut expected, &lhs, &[a]));

            let carry = mul_nx1(&mut lhs, a);
            lhs.push(carry);
            assert_eq!(lhs, expected);
        });
    }

    #[test]
    fn test_addmul_nx1() {
        let any_vec = collection::vec(u64::ANY, 0..10);
        proptest!(|(mut lhs in &any_vec, a in &any_vec, b in u64::ANY)| {
            // `addmul_nx1` requires `lhs.len() == a.len()`.
            let n = lhs.len().min(a.len());
            lhs.truncate(n);
            let a = &a[..n];

            // One spare limb holds the carry, so the reference cannot overflow.
            let mut expected = lhs.clone();
            expected.push(0);
            assert!(!addmul_ref(&mut expected, a, &[b]));

            let carry = addmul_nx1(&mut lhs, a, b);
            lhs.push(carry);
            assert_eq!(lhs, expected);
        });
    }

    fn test_vals(lhs: &[u64], rhs: &[u64], expected: &[u64], expected_overflow: bool) {
        let mut result = vec![0; expected.len()];
        let overflow = addmul(&mut result, lhs, rhs);
        assert_eq!(overflow, expected_overflow);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_empty() {
        test_vals(&[], &[], &[], false);
        test_vals(&[], &[1], &[], false);
        test_vals(&[1], &[], &[], false);
        test_vals(&[1], &[1], &[], true);
        test_vals(&[], &[], &[0], false);
        test_vals(&[], &[1], &[0], false);
        test_vals(&[1], &[], &[0], false);
        test_vals(&[1], &[1], &[1], false);
    }

    #[test]
    fn test_submul_nx1() {
        let mut lhs = [
            15520854688669198950,
            13760048731709406392,
            14363314282014368551,
            13263184899940581802,
        ];
        let a = [
            7955980792890017645,
            6297379555503105007,
            2473663400150304794,
            18362433840513668572,
        ];
        let b = 17275533833223164845;
        let borrow = submul_nx1(&mut lhs, &a, b);
        assert_eq!(lhs, [
            2427453526388035261,
            7389014268281543265,
            6670181329660292018,
            8411211985208067428
        ]);
        assert_eq!(borrow, 17196576577663999042);
    }
}