Support for large finite fields
As far as I understand, the focus of this project is ECC, even though other uses of finite field arithmetic are supported (e.g., Poly1305). Would it be easy to support finite fields large enough that DLog is hard in them, e.g., primes of ~3072 bits?
I know the cool kids use ECC nowadays, but finite field DLog is not dead either. I'm asking specifically because Bitcoin is looking into MuHash [1] for incremental hashing of sets, and in this context finite field DLog is slightly more efficient for this use case [2,3]. I wonder if fiat-crypto could be useful here. Specifically, the linked PR implements finite field arithmetic for 2^3072 - 1103717, but since this is a new feature, the choice of prime is not set in stone.
[1] http://cseweb.ucsd.edu/~Mihir/papers/inchash.pdf
[2] https://lists.linuxfoundation.org/pipermail/bitcoin-dev/2017-May/014337.html
[3] bitcoin/bitcoin#19055
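Below is a minimal Python sketch of the MuHash idea from [1], showing where the big-field arithmetic comes in. A set is hashed to the product, modulo a large prime, of per-element hashes. Adding an element is therefore one modular multiplication, and removing one is a multiplication by a modular inverse. The modulus is the 2^3072 - 1103717 mentioned above. The SHAKE-256 hash-to-field map and the `MuHashSketch` class are hypothetical choices for illustration, not the construction used in the Bitcoin PR.

```python
import hashlib

# Modulus from the linked PR (2^3072 - 1103717), treated here as the field prime.
P = 2**3072 - 1103717

def element_to_field(data: bytes) -> int:
    """Hypothetical hash-to-field map: expand data to 384 bytes with SHAKE-256, reduce mod P."""
    digest = hashlib.shake_256(data).digest(384)
    return int.from_bytes(digest, "little") % P

class MuHashSketch:
    """Illustrative multiplicative set hash; not Bitcoin's MuHash3072 code."""

    def __init__(self):
        self.acc = 1  # the empty set hashes to the multiplicative identity

    def insert(self, data: bytes) -> None:
        # One 3072-bit modular multiplication per inserted element.
        self.acc = self.acc * element_to_field(data) % P

    def remove(self, data: bytes) -> None:
        # Multiply by the modular inverse (pow with exponent -1 needs Python 3.8+).
        # A hash value of 0 would have no inverse; that is negligibly unlikely.
        self.acc = self.acc * pow(element_to_field(data), -1, P) % P

    def digest(self) -> int:
        return self.acc

h = MuHashSketch()
h.insert(b"coin 1")
h.insert(b"coin 2")
h.remove(b"coin 1")
```

Because multiplication is commutative, the digest does not depend on insertion order. This is what makes the scheme incremental.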
Unfortunately, the synthesis time of the current code seems to be exponential in the number of limbs (for word-by-word (WBW) Montgomery) or at least cubic (for unsaturated Solinas), for unclear reasons. Our 12-limb example takes around 8 minutes, and just a couple more limbs bring us up to an hour. 3072-bit primes on 64-bit machines would involve 48 limbs, which is far more than the code can handle at this point. There are two issues here:
- (more fundamental) It's not clear that you really want to unroll all of the loops when there are so many limbs (it blows up code size and might in fact decrease performance), and fiat-crypto currently does not support loops at the field-arithmetic primitive level.
- The aforementioned super-linear factor in synthesis time. We'd love help tracking this down, if you (or anyone else) are so inclined. Here's some code for digging into the performance issues. Note that this isn't the code we actually run (in particular, it doesn't do any arithmetic simplification), but I believe the performance bottlenecks in our actual code already show up in it.
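The 48-limb figure is 3072 / 64. The Python sketch below is a hypothetical illustration, not fiat-crypto output. It computes the limb count and shows a schoolbook multiplication written with loops, the kind of loop that fiat-crypto currently emits fully unrolled. The helper names are made up for this example.

```python
import random

W = 64                      # limb width in bits (64-bit machine)
MASK = (1 << W) - 1

def num_limbs(bits, word=W):
    """Number of word-sized limbs needed for a bits-bit value (ceiling division)."""
    return -(-bits // word)

def to_limbs(x, n):
    """Split a non-negative integer into n little-endian W-bit limbs."""
    return [(x >> (W * i)) & MASK for i in range(n)]

def from_limbs(limbs):
    """Recombine little-endian W-bit limbs into an integer."""
    return sum(limb << (W * i) for i, limb in enumerate(limbs))

def mul_limbs(a, b):
    """Schoolbook product of two n-limb numbers, written with loops.

    The loop body has the same size for every n; fully unrolling it emits
    n*n copies of the multiply-add step (2304 of them for n = 48).
    """
    n = len(a)
    out = [0] * (2 * n)
    for i in range(n):
        carry = 0
        for j in range(n):
            t = out[i + j] + a[i] * b[j] + carry
            out[i + j] = t & MASK
            carry = t >> W
        out[i + n] = carry
    return out

n = num_limbs(3072)
print(n)  # 48
x, y = random.getrandbits(3072), random.getrandbits(3072)
print(from_limbs(mul_limbs(to_limbs(x, n), to_limbs(y, n))) == x * y)  # True
```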
Super-linear time in positional->associational->positional:
Code
Require Import Coq.Lists.List Coq.ZArith.ZArith.
Require Import Crypto.ArithmeticCPS.Core.
Require Import Crypto.ArithmeticCPS.ModOps.
Import ListNotations.
Open Scope Z_scope.
Open Scope nat_scope.
Open Scope list_scope.
Module RT_ExtraAx := Core.RT_Extra Core.RuntimeAxioms.
Module ModOpsAx := ModOps.ModOps Core.RuntimeAxioms.
Module PositionalAx := ArithmeticCPS.Core.Positional Core.RuntimeAxioms.
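(* Convert an nlimbs-limb positional number (64-bit limb weights) to associational form and back. *)
Definition test (nlimbs : nat) (f : list Z) : list Z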
:= let wt := weight 64 1 in
let p := PositionalAx.to_associational wt nlimbs (RT_ExtraAx.expand_list 0%Z f nlimbs) in
PositionalAx.from_associational_cps wt nlimbs p _ id.
Import RuntimeAxioms.
Ltac timetest n := idtac "n=" n; time (idtac; let x := (eval vm_compute in (test n)) in idtac).
Ltac timetest_up_to n :=
lazymatch n with
| O => idtac
| S ?n => timetest_up_to n
end;
timetest n.
Goal True. timetest_up_to 35. Abort.
Table of timing as a function of limbs
| # limbs | time (seconds) |
|---|---|
| 0 | 0 |
| 1 | 0 |
| 2 | 0 |
| 3 | 0 |
| 4 | 0.001 |
| 5 | 0.004 |
| 6 | 0.007 |
| 7 | 0.013 |
| 8 | 0.022 |
| 9 | 0.035 |
| 10 | 0.052 |
| 11 | 0.074 |
| 12 | 0.116 |
| 13 | 0.141 |
| 14 | 0.198 |
| 15 | 0.247 |
| 16 | 0.356 |
| 17 | 0.405 |
| 18 | 0.522 |
| 19 | 0.631 |
| 20 | 0.776 |
| 21 | 0.944 |
| 22 | 1.174 |
| 23 | 1.377 |
| 24 | 1.655 |
| 25 | 2.043 |
| 26 | 2.279 |
| 27 | 2.633 |
| 28 | 3.095 |
| 29 | 3.524 |
| 30 | 3.968 |
| 31 | 4.534 |
| 32 | 5.16 |
| 33 | 5.845 |
| 34 | 6.594 |
| 35 | 7.398 |
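For readers unfamiliar with the two representations, here is a rough Python model of the round trip that `test` times. A positional number is a list of limbs with implicit weights. An associational number is a list of (weight, coefficient) pairs. Converting back places each pair at the largest limb whose weight divides the pair's weight. The model assumes `weight 64 1` gives limb i the weight 2^(64i). The limb-selection rule is our reading of fiat-crypto's `Positional.place`. All Python names here are hypothetical, and this is not the Coq code.

```python
import random

def weight(i, w=64):
    """Weight of limb i; assumed to be 2^(w*i), as with `weight 64 1`."""
    return 1 << (w * i)

def to_associational(n, limbs):
    """Positional list of n limbs -> list of (weight, coefficient) pairs."""
    return [(weight(i), c) for i, c in zip(range(n), limbs)]

def place(term, n):
    """Pick the largest index i < n whose weight divides the term's weight."""
    w, c = term
    for i in reversed(range(n)):       # scans up to n weights per term
        if w % weight(i) == 0:
            return i, c * (w // weight(i))
    raise ValueError("n must be positive")

def from_associational(n, terms):
    """(weight, coefficient) pairs -> positional list of n limbs, without carrying."""
    limbs = [0] * n
    for term in terms:
        i, c = place(term, n)
        limbs[i] += c
    return limbs

n = 8
f = [random.getrandbits(64) for _ in range(n)]
print(from_associational(n, to_associational(n, f)) == f)  # True
```

Even this naive model performs O(n^2) divisibility checks for n limbs, because each of the n terms may scan all n weights. Whether that accounts for the measured growth above is not established here.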
Super-linear time in chained-carries:
Code
Require Import Coq.Lists.List Coq.ZArith.ZArith.
Require Import Crypto.ArithmeticCPS.Core.
Require Import Crypto.ArithmeticCPS.ModOps.
Require Crypto.UnsaturatedSolinasHeuristics.
Import ListNotations.
Open Scope Z_scope.
Open Scope nat_scope.
Open Scope list_scope.
Module RT_ExtraAx := Core.RT_Extra Core.RuntimeAxioms.
Module ModOpsAx := ModOps.ModOps Core.RuntimeAxioms.
Module PositionalAx := ArithmeticCPS.Core.Positional Core.RuntimeAxioms.
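(* Carry nlimbs 64-bit limbs along the chain of indices chosen by carry_chains, with s = 2^(64*nlimbs - 1) and c = [(1, 1)]. *)
Definition test2 (nlimbs : nat) (f : list Z) : list Z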
:= let wt := weight 64 1 in
let s := (2^(64 * Z.of_nat nlimbs - 1))%Z in
let c := [(1, 1)]%Z in
let idxs := UnsaturatedSolinasHeuristics.carry_chains nlimbs s c in
let p := PositionalAx.chained_carries_cps wt nlimbs s c (RT_ExtraAx.expand_list 0%Z f nlimbs) idxs _ id in
p.
Import RuntimeAxioms.
Ltac timetest2 n :=
idtac "n=" n;
restart_timer;
let x := (eval vm_compute in (test2 n)) in
finish_timing ("Tactic call vm_compute");
restart_timer;
let x := (eval vm_compute in x) in
finish_timing ("Tactic call noop-vm_compute").
Ltac timetest2_up_to n :=
lazymatch n with
| O => idtac
| S ?n => timetest2_up_to n
end;
timetest2 n.
Goal True. timetest2_up_to 15. Abort.
Table of timing as a function of limbs
| # limbs | time (s) | time to walk output code (s) |
|---|---|---|
| 0 | 0 | 0 |
| 1 | 0.001 | 0 |
| 2 | 0.004 | 0 |
| 3 | 0.017 | 0.001 |
| 4 | 0.048 | 0.02 |
| 5 | 0.107 | 0.001 |
| 6 | 0.216 | 0.032 |
| 7 | 0.401 | 0.001 |
| 8 | 0.703 | 0.003 |
| 9 | 1.153 | 0.004 |
| 10 | 1.803 | 0.008 |
| 11 | 2.742 | 0.005 |
| 12 | 4.017 | 0.005 |
| 13 | 5.659 | 0.006 |
| 14 | 7.939 | 0.007 |
| 15 | 11.005 | 0.009 |
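Here is a rough Python model of what a chain of carries computes, under simplifying assumptions. Every limb has the same width w, and the modulus is p = 2^(n*w) - c, so a carry out of the top limb wraps into limb 0 multiplied by c. In `test2`, by contrast, s = 2^(64*nlimbs - 1) does not line up with the 64-bit limb boundaries, so the real code handles a more general case. The function names are hypothetical. The carry order is hand-picked, whereas in the Coq test it comes from `UnsaturatedSolinasHeuristics.carry_chains`.

```python
import random

def carry_at(limbs, i, w, c):
    """Carry out of limb i of a radix-2^w number modulo p = 2^(n*w) - c.

    The high part of limb i moves into limb i + 1. Out of the top limb it
    wraps around into limb 0 multiplied by c, since 2^(n*w) = c (mod p).
    """
    out = list(limbs)
    hi = out[i] >> w
    out[i] &= (1 << w) - 1
    if i + 1 < len(out):
        out[i + 1] += hi
    else:
        out[0] += c * hi
    return out

def chained_carries(limbs, idxs, w, c):
    """Apply carry_at at each index of the chain, in order."""
    for i in idxs:
        limbs = carry_at(limbs, i, w, c)
    return limbs

def value(limbs, w):
    """Integer represented by radix-2^w limbs."""
    return sum(limb << (w * i) for i, limb in enumerate(limbs))

# Example: p = 2^255 - 19 with five 51-bit limbs and a hand-picked chain.
W, C, N = 51, 19, 5
P = 2**(N * W) - C
f = [random.getrandbits(60) for _ in range(N)]   # oversized, unsaturated limbs
g = chained_carries(f, [0, 1, 2, 3, 4, 0], W, C)
print(value(f, W) % P == value(g, W) % P)  # True
```

In this model each carry is constant work, so one chain is linear in the number of limbs.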
Super-linear time in WBW Montgomery multiplication:
Code
Require Import Coq.Lists.List Coq.ZArith.ZArith.
Require Import Crypto.ArithmeticCPS.Core.
Require Import Crypto.ArithmeticCPS.ModOps.
Require Import Crypto.ArithmeticCPS.WordByWordMontgomery.
Require Import Crypto.Util.ZUtil.ModInv.
Import ListNotations.
Open Scope Z_scope.
Open Scope nat_scope.
Open Scope list_scope.
Module Import RT_ExtraAx := Core.RT_Extra Core.RuntimeAxioms.
Module ModOpsAx := ModOps.ModOps Core.RuntimeAxioms.
Module PositionalAx := ArithmeticCPS.Core.Positional Core.RuntimeAxioms.
Module WordByWordMontgomeryAx := ArithmeticCPS.WordByWordMontgomery.WordByWordMontgomery Core.RuntimeAxioms.
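(* Word-by-word Montgomery multiplication modulo m = 2^(64*nlimbs - 1) - 1, where m' is the inverse of -m modulo 2^64. *)
Definition test3 (nlimbs : nat) (f g : list Z) : list Z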
:= let wt := weight 64 1 in
let m := (2^(64 * Z.of_nat nlimbs - 1) - 1)%Z in
let m' := match Z.modinv (-m) (2^64) with
| Some m' => m'
| None => 0%Z
end in
WordByWordMontgomeryAx.mulmod 64 nlimbs m m' (expand_list 0%Z f nlimbs) (expand_list 0%Z g nlimbs).
Import RuntimeAxioms.
Ltac timetest3 n :=
idtac "n=" n;
restart_timer;
let x := (eval vm_compute in (test3 n)) in
finish_timing ("Tactic call vm_compute");
restart_timer;
let x := (eval vm_compute in x) in
finish_timing ("Tactic call noop-vm_compute").
Ltac timetest3_up_to n :=
lazymatch n with
| O => idtac
| S ?n => timetest3_up_to n
end;
timetest3 n.
Goal True. timetest3_up_to 9. Abort.
Table of timing as a function of limbs
| # limbs | time (s) | time to walk output code (s) |
|---|---|---|
| 0 | 0 | 0 |
| 1 | 0.002 | 0.001 |
| 2 | 0.022 | 0.007 |
| 3 | 0.157 | 0.065 |
| 4 | 0.608 | 0.124 |
| 5 | 1.876 | 0.277 |
| 6 | 4.954 | 0.56 |
| 7 | 10.851 | 1.254 |
| 8 | 22.151 | 2.249 |
| 9 | 41.965 | 4.421 |
N.B. This code runs with commit 4c54bbb of our project. I expect it'll continue working for some time, but I'm leaving this commit hash here so it's easy to dig up a working version if we need to.
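For reference, here is a textbook word-by-word Montgomery multiplication in Python. It sketches the algorithm that `test3` synthesizes; it is not a transcription of fiat-crypto's `WordByWordMontgomery.mulmod`, whose limb-level details differ. It uses the same m = 2^(64n - 1) - 1 and m' = (-m)^-1 mod 2^64 as `test3`. For n = 1 this m' is 9223372036854775809, the constant that appears in the `Compute test3 1` output further down. The function name is hypothetical, and `pow(x, -1, m)` needs Python 3.8 or later.

```python
import random

def wbw_montgomery_mul(a, b, m, n, w=64):
    """Word-by-word Montgomery product: a * b * 2^(-w*n) mod m.

    Assumes m is odd, 0 <= a, b < m, and m < 2^(w*n).
    """
    base = 1 << w
    m_prime = pow(-m, -1, base)         # m' = (-m)^-1 mod 2^w, as computed in test3
    t = 0
    for i in range(n):                  # consume one w-bit word of a per iteration
        a_i = (a >> (w * i)) & (base - 1)
        t += a_i * b
        q = (t * m_prime) % base        # makes t + q*m divisible by 2^w
        t = (t + q * m) >> w            # exact division; keeps t < 2m
    return t - m if t >= m else t       # final conditional subtraction

n = 2
m = 2**(64 * n - 1) - 1                 # the modulus test3 uses for n limbs
x, y = random.randrange(m), random.randrange(m)
r_inv = pow(2**(64 * n), -1, m)
print(wbw_montgomery_mul(x, y, m, n) == x * y * r_inv % m)  # True
```

Each of the n iterations involves every word of b and m, hidden here inside Python's big integers. Fully unrolled, the algorithm therefore has on the order of n^2 word operations.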
Some additional stats on WBW Montgomery multiplication (without arithmetic simplification):
Code
Require Import Coq.Lists.List Coq.ZArith.ZArith.
Require Import Crypto.ArithmeticCPS.Core.
Require Import Crypto.ArithmeticCPS.ModOps.
Require Import Crypto.ArithmeticCPS.WordByWordMontgomery.
Require Import Crypto.Util.ZUtil.ModInv.
Import ListNotations.
Open Scope Z_scope.
Open Scope nat_scope.
Open Scope list_scope.
Module Import RT_ExtraAx := Core.RT_Extra Core.RuntimeAxioms.
Module ModOpsAx := ModOps.ModOps Core.RuntimeAxioms.
Module PositionalAx := ArithmeticCPS.Core.Positional Core.RuntimeAxioms.
Module WordByWordMontgomeryAx := ArithmeticCPS.WordByWordMontgomery.WordByWordMontgomery Core.RuntimeAxioms.
Definition test3 (nlimbs : nat) (f g : list Z) : list Z
:= let wt := weight 64 1 in
let m := (2^(64 * Z.of_nat nlimbs - 1) - 1)%Z in
let m' := match Z.modinv (-m) (2^64) with
| Some m' => m'
| None => 0%Z
end in
WordByWordMontgomeryAx.mulmod 64 nlimbs m m' (expand_list 0%Z f nlimbs) (expand_list 0%Z g nlimbs).
Import RuntimeAxioms.
Inductive tags := FORALL | FUN | LET | MATCHPAIR.
Inductive HList := Nil | Cons {A} (x : A) (y : nat) (xs : HList).
Fixpoint count (v : HList) : Z
:= match v with
| Nil => 0
| Cons A x n xs => Z.of_nat n + count xs
end%Z.
Delimit Scope hlist_scope with hlist.
Bind Scope hlist_scope with HList.
Notation "( x , y ) :: z" := (Cons x y z) (at level 60, right associativity) : hlist_scope.
Ltac count_size v acc :=
let __ := match goal with _ => idtac acc end in
let inc_count acc tag
:= lazymatch acc with
| context ACC[Cons tag ?c0]
=> let acc := context ACC[Cons tag (S c0)] in
acc
| _ => constr:(Cons tag 1%nat acc)
end in
let handle_fun v acc :=
lazymatch v with
| fun x : ?A => ?f
=> let acc := count_size A acc in
let f' := fresh in
let acc' := fresh in
lazymatch
constr:(
fun x : A
=> match f, acc return _ with
| f', acc'
=> ltac:(let f := (eval cbv delta [f'] in f') in
let acc := (eval cbv delta [acc'] in acc') in
clear f' acc';
let acc := count_size f acc in
refine acc)
end) with
| fun _ => ?acc => acc
| ?e => fail 0 "cannot eliminate functional dependencies of" e
end
| let x : ?A := ?v in ?f
=> let acc := count_size A acc in
let acc := count_size v acc in
let f' := fresh in
let acc' := fresh in
lazymatch
constr:(
let x : A := v in
match f, acc return _ with
| f', acc'
=> ltac:(let f := (eval cbv delta [f'] in f') in
let acc := (eval cbv delta [acc'] in acc') in
clear f' acc';
let acc := count_size f acc in
refine acc)
end) with
| let _ := _ in ?acc => acc
| ?e => fail 0 "cannot eliminate functional dependencies of" e
end
end in
lazymatch v with
| fun x => @?f x => let acc := inc_count acc FUN in handle_fun f acc
| forall x, @?f x => let acc := inc_count acc FORALL in handle_fun f acc
| let x := _ in _ => let acc := inc_count acc LET in handle_fun v acc
| ?f ?x => let acc := count_size f acc in
let acc := count_size x acc in
acc
| match ?x with pair a b => @?f a b end
=> let acc := inc_count acc MATCHPAIR in
let acc := count_size x acc in
handle_fun f acc
| ?v => let iv := constr:(ltac:(match goal with
| _ => is_var v; exact true
| _ => exact false
end)) in
lazymatch iv with
| true => acc
| false => inc_count acc v
end
end.
Axiom ax : forall {T}, T.
Ltac strip_fun v :=
lazymatch v with
| fun x : ?A => ?f
=> let f := lazymatch type of v with
| forall x : ?A, ?P => constr:(match @ax A as x return P with x => f end)
end in
strip_fun f
| ?v => v
end.
Ltac count_size_dlet v :=
lazymatch v with
| @Let_In ?A (fun _ => ?T) ?v (fun x => ?f)
=> let f := constr:(match @ax A return T with x => f end) in
let n := count_size_dlet f in
uconstr:(S n)
| @Let_In ?A ?P ?v (fun x => ?f)
=> let f := constr:(match @ax A as b return P b with x => f end) in
let n := count_size_dlet f in
uconstr:(S n)
| fun x : ?A => ?f
=> let f := lazymatch type of v with
| ?A -> ?T => constr:(match @ax A return T with x => f end)
| forall x : ?A, ?P => constr:(match @ax A as x return P with x => f end)
end in
let n := count_size_dlet f in
n
| _ => constr:(O)
end.
Require Ltac2.Ltac2.
Require Ltac2.Array.
Require Ltac2.Constr.
Require Ltac2.Ltac1.
Require Ltac2.Message.
Require Ltac2.Int.
Require Ltac2.Control.
Module Import WithLtac2.
Import Ltac2.Ltac2.
Import Ltac2.Constr.
Ltac2 rec count_dlet (v : constr) :=
match Unsafe.kind v with
| Unsafe.App _Let_In args
=> match Int.lt 3 (Array.length args) with
| true => Int.add 1 (count_dlet (Array.get args 3))
| false => 0
end
| Unsafe.Lambda _ _ body
=> count_dlet body
| _ => 0
end.
Ltac2 print_count_dlet (v : constr option) :=
match v with
| Some v
=> let n := count_dlet v in
Message.print (Message.concat (Message.of_string "# lets = ") (Message.of_int n))
| None => Control.zero Not_found
end.
Ltac count_dlet_ltac2 := ltac2:(v |- print_count_dlet (Ltac1.to_constr v)).
End WithLtac2.
Ltac timetest3 n :=
idtac "n=" n;
restart_timer;
let x := (eval vm_compute in (test3 n)) in
finish_timing ("Tactic call vm_compute");
restart_timer;
let x := (eval vm_compute in x) in
finish_timing ("Tactic call noop-vm_compute");
count_dlet_ltac2 x.
Ltac timetest3_up_to n :=
lazymatch n with
| O => idtac
| S ?n => timetest3_up_to n
end;
timetest3 n.
Goal True. timetest3_up_to 9. Abort.
And to count the lines of code in the generated C output:
for i in $(seq 1 10); do echo "$i$(printf '\t')$(./src/ExtractionOCaml/word_by_word_montgomery p 64 "2^($i*64)-1" mul --no-primitives | wc -l)"; done
Table
| # limbs | time (s) | time to walk output code (s) | # of `let`s | # of LOC in C |
|---|---|---|---|---|
| 0 | 0 | 0 | 2 | 0 |
| 1 | 0.002 | 0.001 | 30 | 69 |
| 2 | 0.022 | 0.007 | 133 | 138 |
| 3 | 0.157 | 0.065 | 433 | 239 |
| 4 | 0.608 | 0.124 | 1124 | 376 |
| 5 | 1.876 | 0.277 | 2392 | 549 |
| 6 | 4.954 | 0.56 | 4511 | 758 |
| 7 | 10.851 | 1.254 | 7793 | 1003 |
| 8 | 22.151 | 2.249 | 12598 | 1284 |
| 9 | 41.965 | 4.421 | 19334 | 1601 |
Compute test3 0.
= fun _ _ : list Z =>
dlet a : Z := 0%Z in
dlet _ : Z * Z := RT_Z.add_with_get_carry_full 18446744073709551616 0 0
(- a)%RT in
[]
: list Z -> list Z -> list Z
Set Printing Width 200. Compute test3 1.
= fun f g : list Z =>
dlet a : Z := runtime_nth_default 0%Z f 0 in
dlet a0 : Z * Z := RT_Z.mul_split 18446744073709551616 a (runtime_nth_default 0%Z g 0) in
dlet a1 : Z := let (_, H) := a0 in H in
dlet a2 : Z := let (H, _) := a0 in H in
dlet a3 : Z * Z := RT_Z.add_with_get_carry_full 18446744073709551616 0 a2 0 in
dlet a4 : Z * Z := RT_Z.add_with_get_carry_full 18446744073709551616 (let (_, H) := a3 in H) a1 0 in
dlet a5 : Z * Z := RT_Z.add_with_get_carry_full 18446744073709551616 0 0 (let (H, _) := a3 in H) in
dlet a6 : Z * Z := RT_Z.add_with_get_carry_full 18446744073709551616 (let (_, H) := a5 in H) 0 (let (H, _) := a4 in H) in
dlet a7 : Z := let (H, _) := a5 in H in
dlet a8 : Z := let (H, _) := a6 in H in
dlet a9 : Z := (0%Z + (let (_, H) := a6 in H))%RT in
dlet a10 : Z := a7 in
dlet a11 : Z := let (H, _) := RT_Z.mul_split 18446744073709551616 a10 9223372036854775809 in H in
dlet a12 : Z * Z := RT_Z.mul_split 18446744073709551616 a11 9223372036854775807 in
dlet a13 : Z := let (_, H) := a12 in H in
dlet a14 : Z := let (H, _) := a12 in H in
dlet a15 : Z * Z := RT_Z.add_with_get_carry_full 18446744073709551616 0 a14 0 in
dlet a16 : Z * Z := RT_Z.add_with_get_carry_full 18446744073709551616 (let (_, H) := a15 in H) a13 0 in
dlet a17 : Z * Z := RT_Z.add_with_get_carry_full 18446744073709551616 0 a7 (let (H, _) := a15 in H) in
dlet a18 : Z * Z := RT_Z.add_with_get_carry_full 18446744073709551616 (let (_, H) := a17 in H) a8 (let (H, _) := a16 in H) in
dlet a19 : Z * Z := RT_Z.add_with_get_carry_full 18446744073709551616 (let (_, H) := a18 in H) a9 0 in
dlet _ : Z := let (H, _) := a17 in H in
dlet a21 : Z := let (H, _) := a18 in H in
dlet a22 : Z := let (H, _) := a19 in H in
dlet a23 : Z := a21 in
dlet a24 : Z := a22 in
dlet a25 : Z := 0%Z in
dlet a26 : Z := 9223372036854775807%Z in
dlet a27 : Z * Z := RT_Z.add_with_get_carry_full 18446744073709551616 0 a23 (- a26)%RT in
dlet a28 : Z * Z := RT_Z.add_with_get_carry_full 18446744073709551616 (let (_, H) := a27 in H) a24 (- a25)%RT in
[RT_Z.zselect (- (0%Z + (let (_, H) := a28 in H)))%RT (let (H, _) := a27 in H) a23]
: list Z -> list Z -> list Z
The growth shown in the table above is almost perfectly quartic, as we can see on a log-log plot [plot not included].
The number of lines of code in the generated C, by contrast, is almost perfectly quadratic [plot not included].
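The polynomial degrees can also be checked directly from the table, without a plot. For a polynomial of degree d sampled at consecutive integers, the d-th forward differences are constant. The small Python sketch below applies this to the `let`-count and LOC columns for 1 to 9 limbs; the helper name is hypothetical.

```python
def differences(values, order=1):
    """Return the order-th forward differences of a sequence of numbers."""
    for _ in range(order):
        values = [b - a for a, b in zip(values, values[1:])]
    return values

# Columns of the table above for 1..9 limbs.
lets = [30, 133, 433, 1124, 2392, 4511, 7793, 12598, 19334]
loc = [69, 138, 239, 376, 549, 758, 1003, 1284, 1601]

print(differences(loc, 2))   # [32, 36, 36, 36, 36, 36, 36]
print(differences(lets, 4))  # [-8, 88, 38, 48, 48]
```

A constant d-th difference D corresponds to a leading coefficient of D / d!. That suggests roughly 18·n^2 lines of C and 2·n^4 `let`s for large n. With only nine data points, this is suggestive rather than conclusive.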
I see. Thanks for the quick and detailed response. I think it's then obviously not interesting for us at the moment, but I'll watch the issue, of course. Unfortunately, I don't have the resources to help.