JonathanSalwan/Triton

Lifting to LLVM with optimizations

JonathanSalwan opened this issue · 3 comments

Allow the user to apply LLVM optimizations (-O3, -Oz) when lifting to LLVM.

>>> from triton import *
>>> 
>>> ctx = TritonContext(ARCH.X86_64)
>>> ast = ctx.getAstContext()
>>> 
>>> x = ast.variable(ctx.newSymbolicVariable(8))
>>> y = ast.variable(ctx.newSymbolicVariable(8))
>>> 
>>> print(ctx.liftToLLVM((x & ~y) | (~x & y), fname='mba'))
; ModuleID = 'tritonModule'
source_filename = "tritonModule"

define i8 @mba(i8 %SymVar_0, i8 %SymVar_1) {
entry:
  %0 = xor i8 %SymVar_0, -1
  %1 = and i8 %0, %SymVar_1
  %2 = xor i8 %SymVar_1, -1
  %3 = and i8 %SymVar_0, %2
  %4 = or i8 %3, %1
  ret i8 %4
}

>>> print(ctx.liftToLLVM((x & ~y) | (~x & y), fname='mba_opti', optimize=True))
; ModuleID = 'tritonModule'
source_filename = "tritonModule"

; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i8 @mba_opti(i8 %SymVar_0, i8 %SymVar_1) local_unnamed_addr #0 {
entry:
  %0 = xor i8 %SymVar_1, %SymVar_0
  ret i8 %0
}

attributes #0 = { mustprogress nofree norecurse nosync nounwind readnone willreturn }

>>>
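Why can LLVM collapse `@mba` into a single `xor`? Every instruction in `@mba` is bitwise, so each output bit depends only on the same bit position of the inputs. Checking the four one-bit input combinations is therefore enough to show the identity holds for `i8` (or any width). The following is a minimal plain-Python sketch, not part of Triton, that mirrors the IR instructions on single bits. In `i8`, `-1` is all ones, so `xor ..., -1` is a bitwise NOT; on one bit, that is `xor` with `1`. The function names are hypothetical.

```python
from itertools import product


def unoptimized(a: int, b: int) -> int:
    """Mirror the instructions of @mba on single bits (0 or 1)."""
    t0 = a ^ 1        # %0 = xor i8 %SymVar_0, -1   (NOT on one bit)
    t1 = t0 & b       # %1 = and i8 %0, %SymVar_1
    t2 = b ^ 1        # %2 = xor i8 %SymVar_1, -1
    t3 = a & t2       # %3 = and i8 %SymVar_0, %2
    return t3 | t1    # %4 = or i8 %3, %1


def optimized(a: int, b: int) -> int:
    """Mirror @mba_opti: a single xor."""
    return b ^ a      # %0 = xor i8 %SymVar_1, %SymVar_0


# Print the truth table: a, b, unoptimized result, optimized result.
for a, b in product((0, 1), repeat=2):
    print(a, b, unoptimized(a, b), optimized(a, b))
```

Because the two columns agree on every row, the per-bit argument carries over to all eight bits of `i8`.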

Wow! Someday we'll try this during DSE (dynamic symbolic execution). I have an idea: optimize only the path constraints of the path predicate before solving. Optimizing every formula could add a lot of overhead, whereas optimized path constraints would be reused quite often.
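One possible reading of this idea, as a minimal sketch: simplify each path constraint once and reuse the result whenever it reappears in a later path predicate, instead of simplifying every formula. `ConstraintCache` and `toy_simplify` are hypothetical names, and none of this is Triton API. For simplicity, constraints are keyed by their printed form. `toy_simplify` is a stand-in for an expensive simplifier such as `ctx.simplify(node, llvm=True)`.

```python
from typing import Callable, Dict, List


class ConstraintCache:
    """Hypothetical helper: simplify each distinct constraint once, then reuse it."""

    def __init__(self, simplify: Callable[[str], str]) -> None:
        self._simplify = simplify
        self._cache: Dict[str, str] = {}
        self.calls = 0  # how many times the expensive simplifier actually ran

    def get(self, constraint: str) -> str:
        if constraint not in self._cache:
            self.calls += 1
            self._cache[constraint] = self._simplify(constraint)
        return self._cache[constraint]

    def predicate(self, constraints: List[str]) -> List[str]:
        """Return the simplified form of every constraint in a path predicate."""
        return [self.get(c) for c in constraints]


def toy_simplify(expr: str) -> str:
    # Stand-in simplifier: only knows the MBA pattern from this issue.
    known = {"(bvor (bvand x (bvnot y)) (bvand (bvnot x) y))": "(bvxor y x)"}
    return known.get(expr, expr)


cache = ConstraintCache(toy_simplify)
path1 = ["(bvor (bvand x (bvnot y)) (bvand (bvnot x) y))", "(bvult x #x10)"]
path2 = path1 + ["(= y #x00)"]  # a later path shares the first two constraints
print(cache.predicate(path1))
print(cache.predicate(path2))
print(cache.calls)  # each distinct constraint was simplified only once
```

The point of the design is that consecutive paths explored during DSE tend to share prefixes of their path predicates, so the cost of optimization is paid once per constraint rather than once per query.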

Do not hesitate to give feedback if you try it!

In addition to lifting with optimization, it can be useful to get back a simplified Triton node using LLVM. The ctx.simplify function therefore now supports a new argument: llvm.

>>> from triton import *

>>> ctx = TritonContext(ARCH.X86_64)
>>> ast = ctx.getAstContext()

>>> x = ast.variable(ctx.newSymbolicVariable(8, 'x'))
>>> y = ast.variable(ctx.newSymbolicVariable(8, 'y'))

>>> n = (x & ~y) | (~x & y)

>>> # original node
>>> print(n)
(bvor (bvand x (bvnot y)) (bvand (bvnot x) y))

>>> # simplify with the solver
>>> print(ctx.simplify(n, solver=True))
(bvor (bvnot (bvor (bvnot x) y)) (bvnot (bvor x (bvnot y))))

>>> # simplify with LLVM
>>> print(ctx.simplify(n, llvm=True))
(bvxor y x)
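All three printed nodes compute the same 8-bit function. The solver's result is only a De Morgan rewrite of the original (`bvnot (bvor (bvnot x) y)` is `x & ~y`, and `bvnot (bvor x (bvnot y))` is `~x & y`), so it is no smaller. The LLVM result is a single `bvxor`. A minimal plain-Python sketch checks all 65,536 input pairs. The helper names are hypothetical and this is not Triton API; `MASK` models the 8-bit width of the symbolic variables.

```python
MASK = 0xFF  # the symbolic variables x and y are 8 bits wide


def bvnot(v: int) -> int:
    """8-bit bitwise NOT (Python's ~ yields a negative int, so mask it)."""
    return ~v & MASK


def original(x: int, y: int) -> int:
    # (bvor (bvand x (bvnot y)) (bvand (bvnot x) y))
    return (x & bvnot(y)) | (bvnot(x) & y)


def solver_form(x: int, y: int) -> int:
    # (bvor (bvnot (bvor (bvnot x) y)) (bvnot (bvor x (bvnot y))))
    return bvnot(bvnot(x) | y) | bvnot(x | bvnot(y))


def llvm_form(x: int, y: int) -> int:
    # (bvxor y x)
    return y ^ x


mismatches = [
    (x, y)
    for x in range(256)
    for y in range(256)
    if not (original(x, y) == solver_form(x, y) == llvm_form(x, y))
]
print(len(mismatches))  # 0
```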