cucapra/dahlia

Subtraction is right associative

Closed this issue · 1 comment

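Notice the indices into the `r` array (the same happens with `y`). In Fuse, `r[k-i-1]` should parse as `r[(k-i)-1]`, but the generated code contains `r[k-(i-1)]`: subtraction is being treated as right-associative instead of left-associative.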
This is the fuse program I discovered this in:

// BEGIN macro definitions

define(N, 8)
define(DATATYPE, float)

// END macro definitions
import "printer.cpp" {
  def extern print_vec(v: DATATYPE[N]);
}

// Reproducer for the subtraction-associativity bug. The interesting lines are
// the two index expressions of the form `k-i-1` inside the while loops below.

// `decl` arrays; in the generated C++ these become the parameters of kernel().
decl r: DATATYPE[N];
decl y: DATATYPE[N];

// XXX(rachit): This should be a local array.
decl z: DATATYPE[N];

let alpha = 0.0 - r[0];
let beta = 1.0;
---
y[0] := 0.0 - r[0];
---
for (let k = 1..N) {
  beta := (1.0 - alpha * alpha) * beta;

  // Accumulate sum over i in [0, k).
  let sum = 0.0;
  let i = 0;
  while (i < k) {
    // Intended index: (k - i) - 1. The run backend emits k - (i - 1) here.
    sum := sum + r[k-i-1] * y[i];
    // Update loop counter
    i := i + 1;
  }
  ---
  alpha := 0.0 - (r[k] + sum)/beta;

  // Compute the updated values into z, reading the current y.
  i := 0;
  while(i < k) {
    let y_i = y[i];
    ---
    // Intended index: (k - i) - 1. The run backend emits k - (i - 1) here too.
    z[i] := y_i + alpha * y[k-i-1];
    // Update loop counter
    i := i + 1;
  }
  ---
  // Copy the first k entries of z back into y.
  i := 0;
  while(i < k) {
    y[i] := z[i];
    // Update loop counter
    i := i + 1;
  }
  ---
  y[k] := alpha;
}

---
print_vec(y);
print_vec(z);

This is the resulting C++ generated from the run backend:

#include "parser.cpp"
#include "printer.cpp"
/***************** Parse helpers  ******************/

/***************************************************/
// Kernel emitted by the run backend for the Fuse program above (verbatim output).
// Each `//---` line corresponds to a `---` separator in the Fuse source.
void kernel(vector<float> r, vector<float> y, vector<float> z) {
  
  float alpha = (0.0 - r[0]);
  float beta = 1.0;
  //---
  y[0] = (0.0 - r[0]);
  //---
  for(int k = 1; k < 8; k++) {
    beta = ((1.0 - (alpha * alpha)) * beta);
    float sum = 0.0;
    int i = 0;
    while((i < k)) {
      // BUG (subject of this issue): the source index `k-i-1` was emitted as
      // k - (i - 1), i.e. k - i + 1, instead of (k - i) - 1. With i == 0 this
      // reads r[k + 1]; presumably out of range for an 8-element r when k == 7.
      sum = (sum + (r[(k - (i - 1))] * y[i]));
      i = (i + 1);
    }
    //---
    alpha = (0.0 - ((r[k] + sum) / beta));
    i = 0;
    while((i < k)) {
      float y_i = y[i];
      //---
      // BUG: same mis-association as above; the source index was `k-i-1`.
      z[i] = (y_i + (alpha * y[(k - (i - 1))]));
      i = (i + 1);
    }
    //---
    i = 0;
    while((i < k)) {
      y[i] = z[i];
      i = (i + 1);
    }
    //---
    y[k] = alpha;
  }
  //---
  print_vec(y);
  print_vec(z);
}
// Generated entry point: parses the command-line data, extracts the arguments
// named "r", "y" and "z" as 1-D float vectors, and passes them to kernel().
int main(int argc, char** argv) {
  using namespace flattening;
  auto v = parse_data(argc, argv);;
  auto r = get_arg<n_dim_vec_t<float, 1>>("r", "float[]", v);
  auto y = get_arg<n_dim_vec_t<float, 1>>("y", "float[]", v);
  auto z = get_arg<n_dim_vec_t<float, 1>>("z", "float[]", v);
  kernel(r, y, z);
  return 0;
}

Thanks for the issue! I think you can just move around the combinators in parser.scala so that subtraction parses left-associatively, if you want to try making it work.