// minimal_ed448/point.rs

1use core::{
2  ops::{Add, AddAssign, Neg, Sub, SubAssign, Mul, MulAssign},
3  iter::Sum,
4};
5
6use rand_core::RngCore;
7
8use zeroize::Zeroize;
9use subtle::{Choice, CtOption, ConstantTimeEq, ConditionallySelectable, ConditionallyNegatable};
10
11use crypto_bigint::{U448, modular::constant_mod::Residue};
12
13use group::{
14  ff::{Field, PrimeField, PrimeFieldBits},
15  Group, GroupEncoding,
16  prime::PrimeGroup,
17};
18
19use crate::{
20  backend::u8_from_bool,
21  scalar::Scalar,
22  field::{ResidueType, FieldElement, Q_4},
23};
24
25const D: FieldElement =
26  FieldElement(ResidueType::sub(&ResidueType::ZERO, &Residue::new(&U448::from_u16(39081))));
27
28const G_Y: FieldElement = FieldElement(Residue::new(&U448::from_be_hex(concat!(
29  "693f46716eb6bc248876203756c9c7624bea73736ca3984087789c1e",
30  "05a0c2d73ad3ff1ce67c39c4fdbd132c4ed7c8ad9808795bf230fa14",
31))));
32
33const G_X: FieldElement = FieldElement(Residue::new(&U448::from_be_hex(concat!(
34  "4f1970c66bed0ded221d15a622bf36da9e146570470f1767ea6de324",
35  "a3d3a46412ae1af72ab66511433b80e18b00938e2626a82bc70cc05e",
36))));
37
38fn recover_x(y: FieldElement) -> CtOption<FieldElement> {
39  let ysq = y.square();
40  #[allow(non_snake_case)]
41  let D_ysq = D * ysq;
42  (D_ysq - FieldElement::ONE).invert().and_then(|inverted| {
43    let temp = (ysq - FieldElement::ONE) * inverted;
44    let mut x = temp.pow(Q_4);
45    x.conditional_negate(x.is_odd());
46
47    let xsq = x.square();
48    CtOption::new(x, (xsq + ysq).ct_eq(&(FieldElement::ONE + (xsq * D_ysq))))
49  })
50}
51
52/// Ed448 point.
53#[derive(Clone, Copy, Debug)]
54pub struct Point {
55  x: FieldElement,
56  y: FieldElement,
57  z: FieldElement,
58}
59
60impl Zeroize for Point {
61  fn zeroize(&mut self) {
62    self.x.zeroize();
63    self.y.zeroize();
64    self.z.zeroize();
65    let identity = Self::identity();
66    self.x = identity.x;
67    self.y = identity.y;
68    self.z = identity.z;
69  }
70}
71
72const G: Point = Point { x: G_X, y: G_Y, z: FieldElement::ONE };
73
74impl ConstantTimeEq for Point {
75  fn ct_eq(&self, other: &Self) -> Choice {
76    let x1 = self.x * other.z;
77    let x2 = other.x * self.z;
78
79    let y1 = self.y * other.z;
80    let y2 = other.y * self.z;
81
82    x1.ct_eq(&x2) & y1.ct_eq(&y2)
83  }
84}
85
86impl PartialEq for Point {
87  fn eq(&self, other: &Point) -> bool {
88    self.ct_eq(other).into()
89  }
90}
91
92impl Eq for Point {}
93
94impl ConditionallySelectable for Point {
95  fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
96    Point {
97      x: FieldElement::conditional_select(&a.x, &b.x, choice),
98      y: FieldElement::conditional_select(&a.y, &b.y, choice),
99      z: FieldElement::conditional_select(&a.z, &b.z, choice),
100    }
101  }
102}
103
104impl Add for Point {
105  type Output = Point;
106  fn add(self, other: Self) -> Self {
107    // 12 muls, 7 additions, 4 negations
108    let xcp = self.x * other.x;
109    let ycp = self.y * other.y;
110    let zcp = self.z * other.z;
111    #[allow(non_snake_case)]
112    let B = zcp.square();
113    #[allow(non_snake_case)]
114    let E = D * xcp * ycp;
115    #[allow(non_snake_case)]
116    let F = B - E;
117    #[allow(non_snake_case)]
118    let G_ = B + E;
119
120    Point {
121      x: zcp * F * ((self.x + self.y) * (other.x + other.y) - xcp - ycp),
122      y: zcp * G_ * (ycp - xcp),
123      z: F * G_,
124    }
125  }
126}
127
128impl AddAssign for Point {
129  fn add_assign(&mut self, other: Point) {
130    *self = *self + other;
131  }
132}
133
134impl Add<&Point> for Point {
135  type Output = Point;
136  fn add(self, other: &Point) -> Point {
137    self + *other
138  }
139}
140
141impl AddAssign<&Point> for Point {
142  fn add_assign(&mut self, other: &Point) {
143    *self += *other;
144  }
145}
146
147impl Neg for Point {
148  type Output = Point;
149  fn neg(self) -> Self {
150    Point { x: -self.x, y: self.y, z: self.z }
151  }
152}
153
154impl Sub for Point {
155  type Output = Point;
156  #[allow(clippy::suspicious_arithmetic_impl)]
157  fn sub(self, other: Self) -> Self {
158    self + other.neg()
159  }
160}
161
162impl SubAssign for Point {
163  fn sub_assign(&mut self, other: Point) {
164    *self = *self - other;
165  }
166}
167
168impl Sub<&Point> for Point {
169  type Output = Point;
170  fn sub(self, other: &Point) -> Point {
171    self - *other
172  }
173}
174
175impl SubAssign<&Point> for Point {
176  fn sub_assign(&mut self, other: &Point) {
177    *self -= *other;
178  }
179}
180
181impl Group for Point {
182  type Scalar = Scalar;
183  fn random(mut rng: impl RngCore) -> Self {
184    loop {
185      let mut bytes = FieldElement::random(&mut rng).to_repr();
186      let mut_ref: &mut [u8] = bytes.as_mut();
187      mut_ref[56] |= u8::try_from(rng.next_u32() % 2).unwrap() << 7;
188      let opt = Self::from_bytes(&bytes);
189      if opt.is_some().into() {
190        return opt.unwrap();
191      }
192    }
193  }
194  fn identity() -> Self {
195    Point { x: FieldElement::ZERO, y: FieldElement::ONE, z: FieldElement::ONE }
196  }
197  fn generator() -> Self {
198    G
199  }
200  fn is_identity(&self) -> Choice {
201    self.ct_eq(&Self::identity())
202  }
203  fn double(&self) -> Self {
204    // 7 muls, 7 additions, 4 negations
205    let xsq = self.x.square();
206    let ysq = self.y.square();
207    let zsq = self.z.square();
208    let xy = self.x + self.y;
209    #[allow(non_snake_case)]
210    let F = xsq + ysq;
211    #[allow(non_snake_case)]
212    let J = F - zsq.double();
213    Point { x: J * (xy.square() - xsq - ysq), y: F * (xsq - ysq), z: F * J }
214  }
215}
216
217impl Sum<Point> for Point {
218  fn sum<I: Iterator<Item = Point>>(iter: I) -> Point {
219    let mut res = Self::identity();
220    for i in iter {
221      res += i;
222    }
223    res
224  }
225}
226
227impl<'a> Sum<&'a Point> for Point {
228  fn sum<I: Iterator<Item = &'a Point>>(iter: I) -> Point {
229    Point::sum(iter.copied())
230  }
231}
232
233impl Mul<Scalar> for Point {
234  type Output = Point;
235  fn mul(self, mut other: Scalar) -> Point {
236    // Precompute the optimal amount that's a multiple of 2
237    let mut table = [Point::identity(); 16];
238    table[1] = self;
239    for i in 2 .. 16 {
240      table[i] = table[i - 1] + self;
241    }
242
243    let mut res = Self::identity();
244    let mut bits = 0;
245    for (i, mut bit) in other.to_le_bits().iter_mut().rev().enumerate() {
246      bits <<= 1;
247      let mut bit = u8_from_bool(&mut bit);
248      bits |= bit;
249      bit.zeroize();
250
251      if ((i + 1) % 4) == 0 {
252        if i != 3 {
253          for _ in 0 .. 4 {
254            res = res.double();
255          }
256        }
257
258        let mut add_by = Point::identity();
259        #[allow(clippy::needless_range_loop)]
260        for i in 0 .. 16 {
261          #[allow(clippy::cast_possible_truncation)] // Safe since 0 .. 16
262          {
263            add_by = <_>::conditional_select(&add_by, &table[i], bits.ct_eq(&(i as u8)));
264          }
265        }
266        res += add_by;
267        bits = 0;
268      }
269    }
270    other.zeroize();
271    res
272  }
273}
274
275impl MulAssign<Scalar> for Point {
276  fn mul_assign(&mut self, other: Scalar) {
277    *self = *self * other;
278  }
279}
280
281impl Mul<&Scalar> for Point {
282  type Output = Point;
283  fn mul(self, other: &Scalar) -> Point {
284    self * *other
285  }
286}
287
288impl MulAssign<&Scalar> for Point {
289  fn mul_assign(&mut self, other: &Scalar) {
290    *self *= *other;
291  }
292}
293
294impl Point {
295  fn is_torsion_free(&self) -> Choice {
296    ((*self * (Scalar::ZERO - Scalar::ONE)) + self).is_identity()
297  }
298}
299
300impl GroupEncoding for Point {
301  type Repr = <FieldElement as PrimeField>::Repr;
302
303  fn from_bytes(bytes: &Self::Repr) -> CtOption<Self> {
304    // Extract and clear the sign bit
305    let sign = Choice::from(bytes[56] >> 7);
306    let mut bytes = *bytes;
307    let mut_ref: &mut [u8] = bytes.as_mut();
308    mut_ref[56] &= !(1 << 7);
309
310    // Parse y, recover x
311    FieldElement::from_repr(bytes).and_then(|y| {
312      recover_x(y).and_then(|mut x| {
313        x.conditional_negate(x.is_odd().ct_eq(&!sign));
314        let not_negative_zero = !(x.is_zero() & sign);
315        let point = Point { x, y, z: FieldElement::ONE };
316        CtOption::new(point, not_negative_zero & point.is_torsion_free())
317      })
318    })
319  }
320
321  fn from_bytes_unchecked(bytes: &Self::Repr) -> CtOption<Self> {
322    Point::from_bytes(bytes)
323  }
324
325  fn to_bytes(&self) -> Self::Repr {
326    let z = self.z.invert().unwrap();
327    let x = self.x * z;
328    let y = self.y * z;
329
330    let mut bytes = y.to_repr();
331    let mut_ref: &mut [u8] = bytes.as_mut();
332    mut_ref[56] |= x.is_odd().unwrap_u8() << 7;
333    bytes
334  }
335}
336
337impl PrimeGroup for Point {}
338
// Runs the generic prime-order group test suite against `Point`
#[test]
fn test_group() {
  let mut rng = rand_core::OsRng;
  ff_group_tests::group::test_prime_group_bits::<_, Point>(&mut rng);
}
343
// The generator is built from G_X/G_Y, and G_X is the root recover_x picks for G_Y
#[test]
fn generator() {
  assert_eq!(G.x, G_X);
  assert_eq!(G.y, G_Y);
  let recovered = recover_x(G.y).unwrap();
  assert_eq!(recovered, G.x);
}
350
#[test]
fn torsion() {
  use generic_array::GenericArray;

  // y-coordinate of the originally suggested generator, which had a torsion component
  let old_y_bytes = hex::decode(concat!(
    "12796c1532041525945f322e414d434467cfd5c57c9a9af2473b2775",
    "8c921c4828b277ca5f2891fc4f3d79afdf29a64c72fb28b59c16fa51",
    "00",
  ))
  .unwrap();
  let old_y = FieldElement::from_repr(*GenericArray::from_slice(&old_y_bytes)).unwrap();
  let old = Point { x: -recover_x(old_y).unwrap(), y: old_y, z: FieldElement::ONE };
  assert!(!bool::from(old.is_torsion_free()));
}
369
#[test]
fn vector() {
  use generic_array::GenericArray;

  // Decodes a hex-encoded point, panicking if it is not a valid encoding
  let decode_point = |hex_str: &str| {
    let bytes = hex::decode(hex_str).unwrap();
    Point::from_bytes(GenericArray::from_slice(&bytes)).unwrap()
  };

  let doubled_generator = decode_point(concat!(
    "ed8693eacdfbeada6ba0cdd1beb2bcbb98302a3a8365650db8c4d88a",
    "726de3b7d74d8835a0d76e03b0c2865020d659b38d04d74a63e905ae",
    "80",
  ));
  assert_eq!(Point::generator().double(), doubled_generator);

  let scalar_bytes = hex::decode(concat!(
    "6298e1eef3c379392caaed061ed8a31033c9e9e3420726f23b404158",
    "a401cd9df24632adfe6b418dc942d8a091817dd8bd70e1c72ba52f3c",
    "00",
  ))
  .unwrap();
  let scalar = Scalar::from_repr(*GenericArray::from_slice(&scalar_bytes)).unwrap();
  let expected = decode_point(concat!(
    "3832f82fda00ff5365b0376df705675b63d2a93c24c6e81d40801ba2",
    "65632be10f443f95968fadb70d10786827f30dc001c8d0f9b7c1d1b0",
    "00",
  ));
  assert_eq!(Point::generator() * scalar, expected);
}
412
// Ensures `Point::random` terminates (it retries until a candidate encoding decodes)
#[test]
fn random() {
  let _ = Point::random(&mut rand_core::OsRng);
}