mit-plv/fiat-crypto

Difficulty adding a rewrite rule

OwenConoly opened this issue · 9 comments

I'm trying to add a Karatsuba template to fiat-crypto, but bounds analysis isn't smart enough to get the proper bounds for Karatsuba. In particular, I have written a function adk_mul (short for "arbitrary-degree Karatsuba"), and I would like to put tight bounds on its outputs. My idea to accomplish this is as follows:

  1. Add adk_mul as a basic identifier - one of the things in the all_ident_named_interped list in Language/IdentifierParameters.v
  2. Add some logic in AbstractInterpretation/ZRange.v to properly bound the outputs of the adk_mul function.
  3. Add a rule in Rewriter/Rules.v that unfolds the adk_mul identifier. The idea is that this rule will be triggered after the correct bounds are already established for the outputs of the adk_mul function.

I'm struggling with part 3. When I try to add this rule, I get a lengthy error message that I can't make sense of.
Note that the adk_mul and ident_adk_mul functions are defined in Arithmetic/ADK.v.

I tried to reduce the size of the error message and hopefully get something that's comprehensible. It turns out that even trying to add this rule (using a subexpression of the expression that occurred in the rule that I'm trying to add) causes this error message. This error message is much shorter, but I still have no idea what's going on there.

Is there a way to fix these issues, and get the rewrite rule to work? Or should I take a different approach entirely to getting the right bounds for karatsuba multiplication?

This should work in theory, but this part of the code is ad-hoc and unsatisfactory. I have some debugging suggestions at #1609 (comment). If you're down to implement that suggestion, you should get better error messages, and then I can plausibly direct you a bit more. But it looks like maybe the problem is that the machinery is unfolding something it shouldn't be?

Regarding the shorter error message, rules like (Z.to_nat (1 + ((Z.of_nat i + 1)/2 - 1) - Z.of_nat (i - (n - 1))%nat)%Z) aren't supported because the rewriter can't handle non-linear (duplicate) occurrences of non-literals. (Granted, "can't unify" is a bad error message, and if I had the time I'd make it give a better error message on these.)

I see - I was getting confused about what the problem was here, because I made the incorrect assumption that if a was a subexpression of b, and the rewriter failed to handle a = a, then it couldn't handle b = b either. Now it's clear that I wasn't debugging this in a helpful way.

For context: I defined ident_adk_mul := adk_mul, and I added ident_adk_mul to the all_ident_named_interped list in Language/IdentifierParameters.v. Here is some information that's probably more useful than what I provided in my previous comment.

  1. I tried redefining adk_mul (n : nat) (x y : list Z) : list Z := []. When I did this, the rewriter happily ingested the rule forall n x y, ident_adk_mul n x y = adk_mul n x y.
  2. I tried redefining adk_mul (n : nat) (x y : list Z) : list Z := [6]. When I did this, the make_rewriter tactic failed in Arith.v, due to the presence of the rule forall n x y, ident_adk_mul n x y = adk_mul n x y. Here's the error message - I hope you can make some sense of it?
The error message:
Proving Rewriter_Interp...
Tactic call ran for 1.317 secs (1.307u,0.s) (success)
============================
WARNING: UNSOLVED GOAL:
max_const_val : Z
x1 :
(RewriteRules.Compile.value' false
   (type.base (base.type.type_base Compilers.nat)))
x2 : (API.interp_type (type.base (base.type.type_base Compilers.nat)))
H :
(ProofsCommon.Compilers.RewriteRules.Compile.value_interp_related
   (@Compilers.ident_interp) x1 x2)
x0 :
(RewriteRules.Compile.value' false
   (type.base (base.type.list (base.type.type_base Compilers.Z))))
x3 :
(API.interp_type
   (type.base (base.type.list (base.type.type_base Compilers.Z))))
H0 :
(ProofsCommon.Compilers.RewriteRules.Compile.value_interp_related
   (@Compilers.ident_interp) x0 x3)
x4 :
(RewriteRules.Compile.value' false
   (type.base (base.type.list (base.type.type_base Compilers.Z))))
x5 :
(API.interp_type
   (type.base (base.type.list (base.type.type_base Compilers.Z))))
H1 :
(ProofsCommon.Compilers.RewriteRules.Compile.value_interp_related
   (@Compilers.ident_interp) x4 x5)
============================
(expr.interp_related_gen (@Compilers.ident_interp)
   (fun t : API.type =>
    ProofsCommon.Compilers.RewriteRules.Compile.value_interp_related
      (@Compilers.ident_interp)) (#Compilers.ident_cons @ ###6%Z)
   (ADK.adk_mul x2 x3))
============================
============================
WARNING: UNSOLVED GOAL:
max_const_val : Z
x1 :
(RewriteRules.Compile.value' false
   (type.base (base.type.type_base Compilers.nat)))
x2 : (API.interp_type (type.base (base.type.type_base Compilers.nat)))
H :
(ProofsCommon.Compilers.RewriteRules.Compile.value_interp_related
   (@Compilers.ident_interp) x1 x2)
x0 :
(RewriteRules.Compile.value' false
   (type.base (base.type.list (base.type.type_base Compilers.Z))))
x3 :
(API.interp_type
   (type.base (base.type.list (base.type.type_base Compilers.Z))))
H0 :
(ProofsCommon.Compilers.RewriteRules.Compile.value_interp_related
   (@Compilers.ident_interp) x0 x3)
x4 :
(RewriteRules.Compile.value' false
   (type.base (base.type.list (base.type.type_base Compilers.Z))))
x5 :
(API.interp_type
   (type.base (base.type.list (base.type.type_base Compilers.Z))))
H1 :
(ProofsCommon.Compilers.RewriteRules.Compile.value_interp_related
   (@Compilers.ident_interp) x4 x5)
============================
(Compilers.ident_interp Compilers.ident_nil = x5)
============================
Tactic call ran for 0.442 secs (0.441u,0.s) (success)
WARNING: Remaining goal:
max_const_val : Z
x1 :
(RewriteRules.Compile.value' false
   (type.base (base.type.type_base Compilers.nat)))
x2 : (API.interp_type (type.base (base.type.type_base Compilers.nat)))
H :
(ProofsCommon.Compilers.RewriteRules.Compile.value_interp_related
   (@Compilers.ident_interp) x1 x2)
x0 :
(RewriteRules.Compile.value' false
   (type.base (base.type.list (base.type.type_base Compilers.Z))))
x3 :
(API.interp_type
   (type.base (base.type.list (base.type.type_base Compilers.Z))))
H0 :
(ProofsCommon.Compilers.RewriteRules.Compile.value_interp_related
   (@Compilers.ident_interp) x0 x3)
x4 :
(RewriteRules.Compile.value' false
   (type.base (base.type.list (base.type.type_base Compilers.Z))))
x5 :
(API.interp_type
   (type.base (base.type.list (base.type.type_base Compilers.Z))))
H1 :
(ProofsCommon.Compilers.RewriteRules.Compile.value_interp_related
   (@Compilers.ident_interp) x4 x5)
============================
(expr.interp_related_gen (@Compilers.ident_interp)
   (fun t : API.type =>
    ProofsCommon.Compilers.RewriteRules.Compile.value_interp_related
      (@Compilers.ident_interp)) (#Compilers.ident_cons @ ###6%Z)
   (ADK.adk_mul x2 x3))
WARNING: Remaining goal:
max_const_val : Z
x1 :
(RewriteRules.Compile.value' false
   (type.base (base.type.type_base Compilers.nat)))
x2 : (API.interp_type (type.base (base.type.type_base Compilers.nat)))
H :
(ProofsCommon.Compilers.RewriteRules.Compile.value_interp_related
   (@Compilers.ident_interp) x1 x2)
x0 :
(RewriteRules.Compile.value' false
   (type.base (base.type.list (base.type.type_base Compilers.Z))))
x3 :
(API.interp_type
   (type.base (base.type.list (base.type.type_base Compilers.Z))))
H0 :
(ProofsCommon.Compilers.RewriteRules.Compile.value_interp_related
   (@Compilers.ident_interp) x0 x3)
x4 :
(RewriteRules.Compile.value' false
   (type.base (base.type.list (base.type.type_base Compilers.Z))))
x5 :
(API.interp_type
   (type.base (base.type.list (base.type.type_base Compilers.Z))))
H1 :
(ProofsCommon.Compilers.RewriteRules.Compile.value_interp_related
   (@Compilers.ident_interp) x4 x5)
============================
(Compilers.ident_interp Compilers.ident_nil = x5)
File "./src/Rewriter/Passes/Arith.v", line 23, characters 23-36:
Error: Proof is not complete.

  3. With adk_mul defined as in (2), I tried adding the rule ident_adk_mul = adk_mul. This also failed, with an error message similar to the one in (2).
  4. I tried adding the rule [6] = [6]. This worked fine.
  5. I tried adding the rule (fun (n : nat) (x y : list Z) => [6]) = (fun (n : nat) (x y : list Z) => [6]). This failed with an ILLEGAL_ABS_ON_LHS error. I suppose lambdas on the LHS are not allowed.

It seems like the issue is related to the fact that I have an identifier that the rewriter recognizes on the left-hand side, and a regular function on the right-hand side. Do I need to somehow inform the rewriter that these things are equal?

2. I tried redefining adk_mul (n : nat) (x y : list Z) : list Z := [6]. When I did this, the make_rewriter tactic failed in Arith.v, due to the presence of the rule forall n x y, ident_adk_mul n x y = adk_mul n x y. Here's the error message - I hope you can make some sense of it?

For some reason, ident_adk_mul n x y seems to be getting reified on the LHS as [6] rather than as the identifier you added to the list of identifiers. Not sure what's causing this, but I suspect this is the root of the error.

Oh, no, never mind, there must be a symmetry somewhere; I think the unfolding is expected.

I think what's happening is that you're seeing a failure to prove ADK.adk_mul x2 x3 equal to the interpretation of its reification.

What if instead of adding forall n x y, ident_adk_mul n x y = adk_mul n x y. you add the rewrite rule forall n x y, ident_adk_mul n x y = ltac:(let rhs := eval cbv [adk_mul] in (adk_mul n x y) in exact rhs).?
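To make the effect of that ltac:(...) term concrete, here is a minimal stand-alone Coq sketch. It does not touch fiat-crypto's rewriter: adk_mul and ident_adk_mul below are hypothetical stand-ins using the toy body [6] from experiment (2), and adk_mul_unfold_rule is an illustrative name. The point is only that eval cbv [adk_mul] runs while the statement is being elaborated, so the stored equation has the already-unfolded body on its right-hand side instead of a call to adk_mul.

```coq
Require Import Coq.Lists.List Coq.ZArith.ZArith.
Import ListNotations.
Local Open Scope Z_scope.

(* Hypothetical stand-ins for the real definitions in Arithmetic/ADK.v;
   the body [6] mirrors the toy experiment from this thread. *)
Definition adk_mul (n : nat) (x y : list Z) : list Z := [6].
Definition ident_adk_mul := adk_mul.

(* The right-hand side is computed at elaboration time: cbv [adk_mul]
   delta-unfolds adk_mul and beta-reduces the application, so the
   statement that gets stored mentions [6] rather than adk_mul. *)
Definition adk_mul_unfold_rule :=
  forall (n : nat) (x y : list Z),
    ident_adk_mul n x y
    = ltac:(let rhs := eval cbv [adk_mul] in (adk_mul n x y) in exact rhs).

(* Displays the elaborated statement. *)
Print adk_mul_unfold_rule.

(* The equation still holds by conversion, since ident_adk_mul := adk_mul. *)
Lemma adk_mul_unfold_rule_holds : adk_mul_unfold_rule.
Proof. intros n x y; reflexivity. Qed.
```

In fiat-crypto itself the rule would sit in Rewriter/Rules.v with the other rules; the toy above only demonstrates the elaboration step.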

Alternatively: put something like Create HintDb reification_proofs discriminated. in https://github.com/mit-plv/rewriter/blob/c79bbc9148fb8386e188267c684a882d8634293b/src/Rewriter/Language/Pre.v; add a branch | progress repeat autounfold with reification_proofs to https://github.com/mit-plv/rewriter/blob/c79bbc9148fb8386e188267c684a882d8634293b/src/Rewriter/Rewriter/ProofsCommonTactics.v#L501-503 right after assumption; and then add #[global] Hint Unfold adk_mul : reification_proofs. to Language/IdentifierParameters.v or somewhere similar.

What if instead of adding forall n x y, ident_adk_mul n x y = adk_mul n x y. you add the rewrite rule forall n x y, ident_adk_mul n x y = ltac:(let rhs := eval cbv [adk_mul] in (adk_mul n x y) in exact rhs).?

This works, thank you! I assume your second suggestion works as well, since it seems equivalent.
Do you think the second solution would be better in some way, or should I just stick with the first one?

Never mind, your second suggestion does not appear to work. Not exactly that, anyway; perhaps something similar would work.
I'll just use the eval cbv solution then.

I don't know how to close this without making a comment, so here is a comment