Subtraction is parsed as right-associative (it should be left-associative)
Closed this issue · 1 comment
sgpthomas commented
Notice the indices into the r array. In Fuse, r[k-i-1] is compiled to r[(k - (i - 1))], i.e. it is parsed as r[k-(i-1)] instead of r[(k-i)-1]. The same happens to y[k-i-1].
This is the Fuse program in which I discovered this:
// Reproducer for the subtraction-associativity parser bug described above.
// NOTE(review): the recurrence appears to be the Durbin (Levinson-Durbin)
// recursion, as in the PolyBench durbin kernel; confirm if that matters.
// BEGIN macro definitions
define(N, 8)
define(DATATYPE, float)
// END macro definitions
import "printer.cpp" {
def extern print_vec(v: DATATYPE[N]);
}
decl r: DATATYPE[N];
decl y: DATATYPE[N];
// XXX(rachit): This should be a local array.
decl z: DATATYPE[N];
let alpha = 0.0 - r[0];
let beta = 1.0;
---
y[0] := 0.0 - r[0];
---
// k runs from 1 up to, but excluding, the array length
// (k < 8 in the generated C++).
for (let k = 1..N) {
beta := (1.0 - alpha * alpha) * beta;
let sum = 0.0;
let i = 0;
while (i < k) {
// NOTE: r[k-i-1] should mean r[(k-i)-1], since subtraction is
// left-associative. The parser currently groups it as r[k-(i-1)]
// (see the generated C++), a subscript two larger than intended.
// Writing the parentheses explicitly presumably works around this
// (untested).
sum := sum + r[k-i-1] * y[i];
// Update loop counter
i := i + 1;
}
---
alpha := 0.0 - (r[k] + sum)/beta;
i := 0;
while(i < k) {
let y_i = y[i];
---
// NOTE: y[k-i-1] is hit by the same mis-parse; it is emitted as
// y[(k - (i - 1))] in the generated C++.
z[i] := y_i + alpha * y[k-i-1];
// Update loop counter
i := i + 1;
}
---
// Copy the updated prefix back from z into y.
i := 0;
while(i < k) {
y[i] := z[i];
// Update loop counter
i := i + 1;
}
---
y[k] := alpha;
}
---
// Emit the final contents of y and z through the extern print_vec
// imported from printer.cpp.
print_vec(y);
print_vec(z);
This is the resulting C++ generated by the run backend:
#include "parser.cpp"
#include "printer.cpp"
/***************** Parse helpers ******************/
/***************************************************/
// Kernel emitted by the run backend for the Fuse program above, reproduced
// verbatim as evidence for this issue; do not hand-edit.
// NOTE(review): r, y and z are passed by value, so the writes to y and z
// below modify local copies only; their final contents are observed solely
// through the print_vec calls at the end of this function.
void kernel(vector<float> r, vector<float> y, vector<float> z) {
float alpha = (0.0 - r[0]);
float beta = 1.0;
//---
y[0] = (0.0 - r[0]);
//---
// k takes the values 1..7, matching "for (let k = 1..N)" with N = 8.
for(int k = 1; k < 8; k++) {
beta = ((1.0 - (alpha * alpha)) * beta);
float sum = 0.0;
int i = 0;
while((i < k)) {
// BUG: the Fuse source says r[k-i-1], i.e. r[((k - i) - 1)]. The parser
// grouped it right-associatively as k - (i - 1) == k - i + 1. At i == 0
// this reads r[k + 1]; for k == 7 that is r[8], which is out of bounds if
// r holds the 8 elements declared in the Fuse source (the input size is
// not checked here).
sum = (sum + (r[(k - (i - 1))] * y[i]));
i = (i + 1);
}
//---
alpha = (0.0 - ((r[k] + sum) / beta));
i = 0;
while((i < k)) {
float y_i = y[i];
//---
// BUG: same mis-grouping; the source says y[k-i-1], i.e. y[((k - i) - 1)].
// At i == 0 this reads y[k + 1], which is y[8] when k == 7.
z[i] = (y_i + (alpha * y[(k - (i - 1))]));
i = (i + 1);
}
//---
// Copy the updated prefix back from z into y.
i = 0;
while((i < k)) {
y[i] = z[i];
i = (i + 1);
}
//---
y[k] = alpha;
}
//---
print_vec(y);
print_vec(z);
}
// Generated driver: presumably reads the named inputs r, y and z from the
// command-line data (via helpers in the included parser.cpp) and then runs
// kernel on them.
int main(int argc, char** argv) {
using namespace flattening;
// NOTE(review): the doubled ';' on the next line is an empty statement
// emitted by the code generator; it is harmless.
auto v = parse_data(argc, argv);;
// n_dim_vec_t<float, 1> is presumably a 1-D vector of float, matching the
// "float[]" type tag passed alongside each argument name.
auto r = get_arg<n_dim_vec_t<float, 1>>("r", "float[]", v);
auto y = get_arg<n_dim_vec_t<float, 1>>("y", "float[]", v);
auto z = get_arg<n_dim_vec_t<float, 1>>("z", "float[]", v);
kernel(r, y, z);
return 0;
}
rachitnigam commented
Thanks for the issue! I think you can just rearrange the combinators in parser.scala if you want to try making it work.