redwoodresearch/mlab

w1d2: some solutions make assumptions about strides

Closed · 1 comment

For example:

def strided_trace(a):
  N, _ = a.shape
  a_strided = torch.as_strided(a, size=(N,), stride=(N+1,))
  return torch.sum(a_strided)

should be something like:

def strided_trace(a):
  N, _ = a.shape
  ns, ms = a.stride()
  a_strided = torch.as_strided(a, size=(N,), stride=(ns + ms,))
  return torch.sum(a_strided)
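To spell out why `ns + ms` is the right step, here is a rough plain-Python model of strided storage. It does not use PyTorch, and the names `diagonal_sum` and `naive_diagonal_sum` are made up for illustration. Element `(i, i)` of a view lives at `offset + i*ns + i*ms` in the flat storage, so walking the diagonal means stepping by `ns + ms`. A step of `N + 1` only agrees with that when the strides happen to be `(N, 1)`.

```python
def diagonal_sum(storage, n, strides, offset=0):
    """Sum the diagonal of an n x n view over a flat list `storage`."""
    row_stride, col_stride = strides
    # Element (i, i) sits at offset + i*row_stride + i*col_stride.
    step = row_stride + col_stride
    return sum(storage[offset + i * step] for i in range(n))


def naive_diagonal_sum(storage, n, offset=0):
    """Same walk, but with the hard-coded step n + 1."""
    return sum(storage[offset + i * (n + 1)] for i in range(n))


# Flat storage of a 6x6 row-major matrix holding 0..35.
storage = list(range(36))

# Full contiguous view: shape (6, 6), strides (6, 1). Both versions agree.
full_ok = diagonal_sum(storage, 6, (6, 1))
full_naive = naive_diagonal_sum(storage, 6)

# Every other row and column (like a[::2, ::2]): shape (3, 3), strides (12, 2).
# The diagonal is storage[0], storage[14], storage[28].
sub_ok = diagonal_sum(storage, 3, (12, 2))
# The naive step of 4 reads storage[0], storage[4], storage[8] instead.
sub_naive = naive_diagonal_sum(storage, 3)
```

On the contiguous view both functions return 105. On the stepped view the stride-aware version returns 42, while the hard-coded one returns 12.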

I think strided_matmul has a similar issue? It will bite you if the matrix you pass in has "weird" strides, i.e. it is a non-contiguous view such as a[::2, ::2], whose strides are no longer (N, 1). It may be useful to write a test case that feeds in a weird-strided input.
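A test along these lines might look like the sketch below. It assumes the corrected `strided_trace` from above and compares it against `torch.trace` on a few square views of one base tensor. The helper `weird_strided_inputs` and the test name are hypothetical, not part of the course repo.

```python
import torch


def strided_trace(a):
    # Corrected version from this issue: step by the tensor's own strides.
    N, _ = a.shape
    ns, ms = a.stride()
    a_strided = torch.as_strided(a, size=(N,), stride=(ns + ms,))
    return torch.sum(a_strided)


def weird_strided_inputs():
    # Hypothetical helper: square views over one 6x6 base tensor.
    base = torch.arange(36.0).reshape(6, 6)
    return {
        "contiguous": base,             # strides (6, 1)
        "inner_slice": base[1:5, 1:5],  # strides (6, 1), nonzero storage offset
        "step_slice": base[::2, ::2],   # strides (12, 2)
    }


def test_strided_trace_matches_torch_trace():
    for name, a in weird_strided_inputs().items():
        expected = torch.trace(a)
        actual = strided_trace(a)
        assert torch.equal(actual, expected), f"{name}: {actual} != {expected}"


test_strided_trace_matches_torch_trace()
```

Note that a plain transpose of a contiguous square matrix would not catch the bug: its strides are (1, N), which still sum to N + 1. Sliced views are what expose the hard-coded step. This relies on `torch.as_strided` keeping the input's storage offset when `storage_offset` is not given.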

Kiv commented

Thanks Ben! You're right about the issue and I've added test cases for this in the new repo for MLAB2.