1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
#![cfg_attr(not(feature = "std"), no_std)]
#![doc = include_str!("../README.md")]

use core::ops::Deref;

use rand_core::{RngCore, CryptoRng};

use zeroize::{Zeroize, Zeroizing};

use transcript::Transcript;

use ff::{Field, PrimeField};
use group::prime::PrimeGroup;

#[cfg(feature = "serialize")]
use std::io::{self, Error, Read, Write};

/// A cross-group DLEq proof capable of proving that two public keys, across two different curves,
/// share a private key.
#[cfg(feature = "experimental")]
pub mod cross_group;

#[cfg(test)]
mod tests;

// Produce a non-biased challenge from the transcript in the specified field
pub(crate) fn challenge<T: Transcript, F: PrimeField>(transcript: &mut T) -> F {
  // From here, there are three ways to get a scalar under the ff/group API
  // 1: Scalar::random(ChaCha20Rng::from_seed(self.transcript.rng_seed(b"challenge")))
  // 2: Grabbing a UInt library to perform reduction by the modulus, then determining endianness
  //    and loading it in
  // 3: Iterating over each byte and manually doubling/adding. This is simplest

  let mut challenge = F::ZERO;

  // Get a wide amount of bytes to safely reduce without bias
  // In most cases, <=1.5x bytes is enough. 2x is still standard and there's some theoretical
  // groups which may technically require more than 1.5x bytes for this to work as intended
  let target_bytes = ((usize::try_from(F::NUM_BITS).unwrap() + 7) / 8) * 2;
  let mut challenge_bytes = transcript.challenge(b"challenge");
  let challenge_bytes_len = challenge_bytes.as_ref().len();
  // If the challenge is 32 bytes, and we need 64, we need two challenges
  let needed_challenges = (target_bytes + (challenge_bytes_len - 1)) / challenge_bytes_len;

  // The following algorithm should be equivalent to a wide reduction of the challenges,
  // interpreted as concatenated, big-endian byte string
  let mut handled_bytes = 0;
  'outer: for _ in 0 ..= needed_challenges {
    // Cursor of which byte of the challenge to use next
    let mut b = 0;
    while b < challenge_bytes_len {
      // Get the next amount of bytes to attempt
      // Only grabs the needed amount of bytes, up to 8 at a time (u64), so long as they're
      // available in the challenge
      let chunk_bytes = (target_bytes - handled_bytes).min(8).min(challenge_bytes_len - b);

      let mut chunk = 0;
      for _ in 0 .. chunk_bytes {
        chunk <<= 8;
        chunk |= u64::from(challenge_bytes.as_ref()[b]);
        b += 1;
      }
      // Add this chunk
      challenge += F::from(chunk);

      handled_bytes += chunk_bytes;
      // If we've reached the target amount of bytes, break
      if handled_bytes == target_bytes {
        break 'outer;
      }

      // Shift over by however many bits will be in the next chunk
      let next_chunk_bytes = (target_bytes - handled_bytes).min(8).min(challenge_bytes_len);
      for _ in 0 .. (next_chunk_bytes * 8) {
        challenge = challenge.double();
      }
    }

    // Secure thanks to the Transcript trait having a bound of updating on challenge
    challenge_bytes = transcript.challenge(b"challenge_extension");
  }

  challenge
}

// Helper function to read a scalar
#[cfg(feature = "serialize")]
fn read_scalar<R: Read, F: PrimeField>(r: &mut R) -> io::Result<F> {
  let mut repr = F::Repr::default();
  r.read_exact(repr.as_mut())?;
  let scalar = F::from_repr(repr);
  if scalar.is_none().into() {
    Err(Error::other("invalid scalar"))?;
  }
  Ok(scalar.unwrap())
}

/// Error for DLEq proofs.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum DLEqError {
  /// The proof was invalid.
  InvalidProof,
}

/// A proof that points have the same discrete logarithm across generators.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroize)]
pub struct DLEqProof<G: PrimeGroup<Scalar: Zeroize>> {
  c: G::Scalar,
  s: G::Scalar,
}

#[allow(non_snake_case)]
impl<G: PrimeGroup<Scalar: Zeroize>> DLEqProof<G> {
  fn transcript<T: Transcript>(transcript: &mut T, generator: G, nonce: G, point: G) {
    transcript.append_message(b"generator", generator.to_bytes());
    transcript.append_message(b"nonce", nonce.to_bytes());
    transcript.append_message(b"point", point.to_bytes());
  }

  /// Prove that the points created by `scalar * G`, for each specified generator, share a discrete
  /// logarithm.
  pub fn prove<R: RngCore + CryptoRng, T: Transcript>(
    rng: &mut R,
    transcript: &mut T,
    generators: &[G],
    scalar: &Zeroizing<G::Scalar>,
  ) -> DLEqProof<G> {
    let r = Zeroizing::new(G::Scalar::random(rng));

    transcript.domain_separate(b"dleq");
    for generator in generators {
      // R, A
      Self::transcript(transcript, *generator, *generator * r.deref(), *generator * scalar.deref());
    }

    let c = challenge(transcript);
    // r + ca
    let s = (c * scalar.deref()) + r.deref();

    DLEqProof { c, s }
  }

  // Transcript a specific generator/nonce/point (G/R/A), as used when verifying a proof.
  // This takes in the generator/point, and then the challenge and solution to calculate the nonce.
  fn verify_statement<T: Transcript>(
    transcript: &mut T,
    generator: G,
    point: G,
    c: G::Scalar,
    s: G::Scalar,
  ) {
    // s = r + ca
    // sG - cA = R
    // R, A
    Self::transcript(transcript, generator, (generator * s) - (point * c), point);
  }

  /// Verify the specified points share a discrete logarithm across the specified generators.
  pub fn verify<T: Transcript>(
    &self,
    transcript: &mut T,
    generators: &[G],
    points: &[G],
  ) -> Result<(), DLEqError> {
    if generators.len() != points.len() {
      Err(DLEqError::InvalidProof)?;
    }

    transcript.domain_separate(b"dleq");
    for (generator, point) in generators.iter().zip(points) {
      Self::verify_statement(transcript, *generator, *point, self.c, self.s);
    }

    if self.c != challenge(transcript) {
      Err(DLEqError::InvalidProof)?;
    }

    Ok(())
  }

  /// Write a DLEq proof to something implementing Write.
  #[cfg(feature = "serialize")]
  pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
    w.write_all(self.c.to_repr().as_ref())?;
    w.write_all(self.s.to_repr().as_ref())
  }

  /// Read a DLEq proof from something implementing Read.
  #[cfg(feature = "serialize")]
  pub fn read<R: Read>(r: &mut R) -> io::Result<DLEqProof<G>> {
    Ok(DLEqProof { c: read_scalar(r)?, s: read_scalar(r)? })
  }

  /// Serialize a DLEq proof to a `Vec<u8>`.
  #[cfg(feature = "serialize")]
  pub fn serialize(&self) -> Vec<u8> {
    let mut res = vec![];
    self.write(&mut res).unwrap();
    res
  }
}

/// A proof that multiple series of points each have a single discrete logarithm across generators.
///
/// This is effectively n distinct DLEq proofs, one for each discrete logarithm and its points
/// across some generators, yet with a smaller overall proof size.
#[cfg(feature = "std")]
#[derive(Clone, PartialEq, Eq, Debug, Zeroize)]
pub struct MultiDLEqProof<G: PrimeGroup<Scalar: Zeroize>> {
  c: G::Scalar,
  s: Vec<G::Scalar>,
}

#[cfg(feature = "std")]
#[allow(non_snake_case)]
impl<G: PrimeGroup<Scalar: Zeroize>> MultiDLEqProof<G> {
  /// Prove for each scalar that the series of points created by multiplying it against its
  /// matching generators share a discrete logarithm.
  /// This function panics if `generators.len() != scalars.len()`.
  pub fn prove<R: RngCore + CryptoRng, T: Transcript>(
    rng: &mut R,
    transcript: &mut T,
    generators: &[Vec<G>],
    scalars: &[Zeroizing<G::Scalar>],
  ) -> MultiDLEqProof<G> {
    assert_eq!(
      generators.len(),
      scalars.len(),
      "amount of series of generators doesn't match the amount of scalars"
    );

    transcript.domain_separate(b"multi_dleq");

    let mut nonces = vec![];
    for (i, (scalar, generators)) in scalars.iter().zip(generators).enumerate() {
      // Delineate between discrete logarithms
      transcript.append_message(b"discrete_logarithm", i.to_le_bytes());

      let nonce = Zeroizing::new(G::Scalar::random(&mut *rng));
      for generator in generators {
        DLEqProof::transcript(
          transcript,
          *generator,
          *generator * nonce.deref(),
          *generator * scalar.deref(),
        );
      }
      nonces.push(nonce);
    }

    let c = challenge(transcript);

    let mut s = vec![];
    for (scalar, nonce) in scalars.iter().zip(nonces) {
      s.push((c * scalar.deref()) + nonce.deref());
    }

    MultiDLEqProof { c, s }
  }

  /// Verify each series of points share a discrete logarithm against their matching series of
  /// generators.
  pub fn verify<T: Transcript>(
    &self,
    transcript: &mut T,
    generators: &[Vec<G>],
    points: &[Vec<G>],
  ) -> Result<(), DLEqError> {
    if points.len() != generators.len() {
      Err(DLEqError::InvalidProof)?;
    }
    if self.s.len() != generators.len() {
      Err(DLEqError::InvalidProof)?;
    }

    transcript.domain_separate(b"multi_dleq");
    for (i, (generators, points)) in generators.iter().zip(points).enumerate() {
      if points.len() != generators.len() {
        Err(DLEqError::InvalidProof)?;
      }

      transcript.append_message(b"discrete_logarithm", i.to_le_bytes());
      for (generator, point) in generators.iter().zip(points) {
        DLEqProof::verify_statement(transcript, *generator, *point, self.c, self.s[i]);
      }
    }

    if self.c != challenge(transcript) {
      Err(DLEqError::InvalidProof)?;
    }

    Ok(())
  }

  /// Write a multi-DLEq proof to something implementing Write.
  #[cfg(feature = "serialize")]
  pub fn write<W: Write>(&self, w: &mut W) -> io::Result<()> {
    w.write_all(self.c.to_repr().as_ref())?;
    for s in &self.s {
      w.write_all(s.to_repr().as_ref())?;
    }
    Ok(())
  }

  /// Read a multi-DLEq proof from something implementing Read.
  #[cfg(feature = "serialize")]
  pub fn read<R: Read>(r: &mut R, discrete_logs: usize) -> io::Result<MultiDLEqProof<G>> {
    let c = read_scalar(r)?;
    let mut s = vec![];
    for _ in 0 .. discrete_logs {
      s.push(read_scalar(r)?);
    }
    Ok(MultiDLEqProof { c, s })
  }

  /// Serialize a multi-DLEq proof to a `Vec<u8>`.
  #[cfg(feature = "serialize")]
  pub fn serialize(&self) -> Vec<u8> {
    let mut res = vec![];
    self.write(&mut res).unwrap();
    res
  }
}