mirage/eqaf

Expose `select`

cfcs opened this issue · 5 comments

cfcs commented

I think it might be useful to expose select functions like

  val one_if_not_zero : int -> int
  (** [one_if_not_zero n] is a constant-time version of
      [if n <> 0 then 1 else 0]: it is [1] for every non-zero [n] (including
      negative values and [min_int]) and [0] for [n = 0].
      This is functionally equivalent to [!!n] in the C programming language. *)

  val select_int : int -> int -> int -> int
  (** [select_int choose_b a b] is [a] if [choose_b = 0] and [b] otherwise.
      The selection is performed without branching, so it should not be
      possible for a measuring adversary to determine anything about the
      values of [choose_b], [a], or [b]. *)

Conceptually, the implementation could look like this:

(** [one_if_not_zero n] is [1] when [n <> 0] and [0] when [n = 0], computed
    without branching (C equivalent: [!!n]).
    Open question: is there a faster way to do this in OCaml? In effect we
    cast an integer to a bool and back again; arguably the Stdlib could
    provide such a primitive. *)
let [@inline] one_if_not_zero n =
  (* For any non-zero [n], at least one of [n] and [-n] is negative
     ([min_int] is its own negation and already negative), while for [n = 0]
     both are 0. OR-ing them therefore sets the sign bit exactly when
     [n <> 0]. The arithmetic shift smears that sign bit into [-1] or [0],
     and [land 1] maps this to [1] or [0]. *)
  let sign_smear = (n lor (-n)) asr Sys.int_size in
  sign_smear land 1
  
(** [select_int choose_b a b] is [a] when [choose_b = 0] and [b] otherwise.
    Exactly one of the two products below is non-zero (the other is
    multiplied by 0), so OR-ing them yields the selected value without a
    branch. *)
let [@inline] select_int choose_b a b =
  let pick_b = one_if_not_zero choose_b in (* 1 iff choose_b <> 0 *)
  let pick_a = 1 - pick_b in                (* 1 iff choose_b = 0 *)
  (a * pick_a) lor (b * pick_b)

I believe we can refactor a bunch of our existing code to use these primitives if we added them to the library.

Code that prompted suggestion

Edit: It looks like more recent versions of OCaml provide Stdlib.Bool.to_int, which, at least on Godbolt, compiles to a version with a branch: https://ocaml.godbolt.org/z/kbygwA

I'm not sure why such poor code is emitted; I would have expected a branch-free implementation. Maybe something to ask a compiler hacker about...

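pextq %rax, %rax, %rax ; pack all set bits (incl. the GC/tag bit) into the low bits
andq $2, %rax          ; bit 1 is set iff n had a set bit besides the tag bit
orq $1, %rax           ; restore the GC/tag bit: %rax is now 1 (false) or 3 (true)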

or at least

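dec %rax                    ; 0 if only the GC/tag bit was set (i.e. n = 0)
test %rax, %rax             ; set ZF if %rax is zero (dec already sets it too)
setnz %al                   ; %al = 1 if ZF is clear (n <> 0), 0 otherwise
movzbq %al, %rax            ; clear the upper bits left over from dec
leaq 1(%rax, %rax, 1), %rax ; shift the flag left by 1 and add the GC/tag bit
; %rax is 1 (= 0 + GC bit) or 3 (= 1<<1 + GC bit)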

I can look at that next week, thanks 👍

cfcs commented

Hmm, I made a mistake earlier. Here's what I believe is a better proposal.
Using these primitives we can implement ct_find_uint8 from Nocrypto like this:

(** [minus_one_or_less n] is [1] if [n <= -1] (i.e. [n] is negative) and
    [0] otherwise. It reads the sign bit of [n] without branching: the
    arithmetic shift yields [-1] for negative [n] and [0] otherwise, and
    [land 1] turns that into [1] or [0]. *)
let [@inline] minus_one_or_less n =
  (n asr (Sys.int_size - 1)) land 1

(** [one_if_not_zero n] is [1] when [n <> 0] and [0] when [n = 0], computed
    without branching. Equivalent to C's [!!n].
    Open question: is there a faster way to do this in OCaml? In effect we
    cast an integer to a bool and back again; arguably the Stdlib could
    provide such a primitive. *)
let [@inline] one_if_not_zero n =
  (* For n <> 0, at least one of [n] and [-n] is negative ([min_int] is
     negative and its own negation); for n = 0 both are 0. So the OR below
     has its sign bit set exactly when n <> 0. *)
  let sign_set_iff_nonzero = n lor (-n) in
  minus_one_or_less sign_set_iff_nonzero

(** [zero_if_not_zero n] is [1] if [n = 0] and [0] otherwise, computed
    without branching. Equivalent to C's [!n]. *)
let [@inline] zero_if_not_zero n =
  (* [one_if_not_zero n] is always 0 or 1, so XOR-ing with 1 flips it.
     Note: subtracting 1 instead (e.g. [one_if_not_zero n - 1]) would yield
     -1 rather than 1 for n = 0, which is not C's [!n]. *)
  one_if_not_zero n lxor 1

(** [select_int choose_b a b] is [a] if [choose_b = 0] and [b] otherwise,
    computed with bit masks instead of a branch. *)
let [@inline] select_int choose_b a b =
  (* [choose_b lor (-choose_b)] has its sign bit set iff [choose_b <> 0].
     The arithmetic shift turns that into an all-ones mask ([-1]) in that
     case and [0] otherwise; its complement selects [a]. *)
  let take_b = (choose_b lor (-choose_b)) asr Sys.int_size in
  let take_a = lnot take_b in
  (b land take_b) lor (a land take_a)
 
(** [ct_find_uint8 ~off ~f cs] is the smallest index [i] with
    [off <= i < String.length cs] such that [f (Char.code cs.[i]) = 0],
    or [-1] if there is no such index.
    - A negative [off] is treated as [0].
    - If [off >= String.length cs], the result is [-1].
    - [f] must return [0] when its predicate holds and a non-zero value
      otherwise. Using [0] for "match" means [compare] and friends can be
      used directly. [f] is applied to every byte of [cs], so it must itself
      be constant-time for the whole search to be.

    To avoid leaking [off] or the match position through timing, the whole
    string is always scanned, and the result is updated with [select_int]
    rather than a branch.

    An rfind variant would additionally OR a flag into the selector that
    becomes non-zero once a match has been recorded. Scanning downwards,
    this keeps the first (highest) match instead of the last one. *)
let ct_find_uint8
    ~off (* a required label: [?(off = 0)] would introduce a branch *)
    ~f cs =
  (* Clamp a negative [off] to 0 without branching: [sign] is [-1] when
     [off < 0] and [0] otherwise. This also keeps [i - off] below from
     overflowing when [off] is close to [min_int]. *)
  let sign = off asr (Sys.int_size - 1) in
  let off = off land lnot sign in
  let found_idx = ref (-1) in
  for i = String.length cs - 1 downto 0 do
    let byte = Char.code (String.unsafe_get cs i) in
    (* Constant-time version of:
         if off <= i && f byte = 0 then i else !found_idx
       [(i - off) land min_int] is non-zero exactly when [i < off]. Because
       we scan downwards, the last index recorded is the lowest match. *)
    found_idx := select_int (((i - off) land min_int) lor f byte) i !found_idx
  done;
  !found_idx

I just want to check whether inlining still happens between the eqaf module and the mirage-crypto module (cross-module optimization). If it does, exposing these functions would indeed be helpful.

About your comment:

Is there a faster way to do this in OCaml?
Essentially we are casting an integer to a bool,
then casting back to integer. It feels like maybe the Stdlib
should provide something like that?
Equivalent to C: [!!n]

bool and int have the same representation in the OCaml runtime. But I don't quite understand your goal when you write (-n lor n).

cfcs commented

@dinosaure That is indeed interesting. My assumption was that this was the case, but I have not tested it, and I don't know whether additional precautions are needed, e.g. do you need inlining qualifiers in the mli as well as in the implementation?

(-n lor n) is basically about setting the sign bit iff n <> 0.
The minus_one_or_less shifts by everything but the sign bit, so if you had the sign bit set, you now have 1, and otherwise you have 0.
If we had simply used n lsr (Sys.int_size-1), we would get 0 for non-negative integers and 1 for negative integers (where the sign bit is set).
If we use (-n) lsr (Sys.int_size-1) instead, the function returns 1 for positive numbers and 0 for zero and for negative numbers (except min_int, which is its own negation and therefore also yields 1).
The integer ~- operation on two's complement systems is basically let (~-) n = (lnot n) + 1, just like you can define Stdlib.succ as let succ n = n - lnot 0.
What we are taking advantage of with this trick is that positive numbers grow "up" from 0000 (0) towards 0111 (max_int), and negative numbers grow "down" from 1111 (-1) towards 1000 (min_int). This means there are only two numbers n for which -n, i.e. (lnot n) + 1, is identical to n: min_int and 0 (shown here with 4-bit integers):

1000 (* lnot: *)
0111 (* +1: *)
1000

0000 (* lnot: *)
1111 (* + 1: *)
0000

Zero is also the only non-negative number whose negation does not have the sign bit set; negative numbers already have the sign bit set in n itself, which is why we combine n and -n with lor. I played around with a few different ways to achieve something similar, and this was the best I could do without resorting to inline assembly or Obj.magic. The lor operation is also nice because it lets the compiler omit an orq $1 instruction at the end (to set the GC/tag bit): any tagged number (like our n) lor'ed with anything else (like -n) keeps its tag bit, because n itself is unchanged. Looking at the resulting code on Godbolt, the OCaml compiler does seem to perform this optimization.

I introduced one_if_not_zero to be able to multiply by 1 or 0 respectively. We could have chosen an API where the user must supply either 1 or 0, but that feels like it will be error-prone, and with this API it's very similar to what is available to people writing timing-sensitive code in C, so hopefully easier to port code.
Multiplication generated really ugly code, so I replaced it with the (a & (~mask)) | (b & mask) instead.
This does not generate ideal code either, since the compiler does not seem to realize that (a land (lnot mask)) lor (b land mask) always has the tag bit set: whether or not mask has the tag bit, lnot mask will have the opposite. Both a and b come from an outer scope, so both have the tag bit set. Since we combine the two halves with lor, either of them having the tag bit is enough, so we don't have to set it explicitly. It's the best I could come up with in pure OCaml without branching on the values, though.

Perhaps my explanation here isn't the best, if I lost you somewhere in this stream of consciousness, let me know and I'll try to rephrase it!

Closed by #19