w1d2: some solutions make assumptions about strides
bmillwood commented
For example:
import torch

def strided_trace(a):
    N, _ = a.shape
    # Assumes a contiguous layout, stride (N, 1), so the diagonal step is N + 1
    a_strided = torch.as_strided(a, size=(N,), stride=(N+1,))
    return torch.sum(a_strided)
should be something like:
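def strided_trace(a):
    N, _ = a.shape
    ns, ms = a.stride()  # row stride and column stride, in elements
    # One step along the diagonal moves down one row and across one column
    a_strided = torch.as_strided(a, size=(N,), stride=(ns + ms,))
    return torch.sum(a_strided)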
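To make the difference concrete, here is a small check in the spirit of the weird-strided test case. It assumes `torch` is installed; `strided_trace_naive` and `strided_trace_fixed` are hypothetical names for the two versions above, and `torch.trace` serves as the reference.

```python
import torch

def strided_trace_naive(a):
    # Hard-codes the diagonal step as N + 1 (right only for layouts like stride (N, 1)).
    N, _ = a.shape
    return torch.sum(torch.as_strided(a, size=(N,), stride=(N + 1,)))

def strided_trace_fixed(a):
    # Reads the real strides: one diagonal step = one row + one column.
    N, _ = a.shape
    ns, ms = a.stride()
    return torch.sum(torch.as_strided(a, size=(N,), stride=(ns + ms,)))

# Every other column of a 4x8 tensor: a 4x4 view with stride (8, 2).
base = torch.arange(32.0).reshape(4, 8)
sliced = base[:, ::2]
print(sliced.stride())                     # (8, 2)
print(torch.trace(sliced).item())          # 60.0
print(strided_trace_fixed(sliced).item())  # 60.0
print(strided_trace_naive(sliced).item())  # 30.0 (reads the wrong elements)

# A plain transpose has stride (1, 4), and 1 + 4 == N + 1, so it does NOT expose the bug.
transposed = torch.arange(16.0).reshape(4, 4).T
print(strided_trace_naive(transposed).item())  # 30.0, same as torch.trace
```

Note the last case: a transposed square matrix happens to have row stride plus column stride equal to N + 1, so a test that only transposes would pass the buggy version. A sliced view with a step is a more reliable "weird" input for the trace.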
I think strided_matmul has a similar issue? It'll bite you if the matrix you pass in has "weird" strides (for example, a transposed or sliced view). It may be useful to write a test case that feeds in a weird-strided input.
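If strided_matmul broadcasts both operands to an (i, j, k) view with `as_strided`, the same fix applies: take every stride from `matA.stride()` and `matB.stride()` instead of assuming a contiguous layout. The sketch below is a hypothetical stride-aware version (the repo's actual solution may differ), followed by the kind of weird-strided test suggested above: transposed views, stepped slices, and slices with a nonzero storage offset, each checked against `A @ B` with `torch.testing.assert_close`.

```python
import torch

def strided_matmul(matA, matB):
    # Hypothetical stride-aware sketch: all strides and offsets come from the inputs.
    i, j = matA.shape
    j2, k = matB.shape
    assert j == j2, "inner dimensions must match"
    a_row, a_col = matA.stride()
    b_row, b_col = matB.stride()
    # A viewed as (i, j, k): [r, s, t] -> matA[r, s]; stride 0 repeats it along k.
    expanded_A = torch.as_strided(matA, size=(i, j, k), stride=(a_row, a_col, 0),
                                  storage_offset=matA.storage_offset())
    # B viewed as (i, j, k): [r, s, t] -> matB[s, t]; stride 0 repeats it along i.
    expanded_B = torch.as_strided(matB, size=(i, j, k), stride=(0, b_row, b_col),
                                  storage_offset=matB.storage_offset())
    # Multiply elementwise, then sum over the shared dimension j.
    return (expanded_A * expanded_B).sum(dim=1)

torch.manual_seed(0)
# (name, A, B): a contiguous baseline plus weird-strided views, all 3x4 @ 4x5.
weird_cases = [
    ("contiguous", torch.randn(3, 4), torch.randn(4, 5)),
    ("transposed", torch.randn(4, 3).T, torch.randn(5, 4).T),
    ("stepped slice", torch.randn(3, 8)[:, ::2], torch.randn(4, 10)[:, ::2]),
    ("offset slice", torch.randn(5, 6)[2:, 1:5], torch.randn(6, 7)[1:5, 2:]),
]
for name, A, B in weird_cases:
    torch.testing.assert_close(strided_matmul(A, B), A @ B)
    print(f"{name}: ok, strides {A.stride()} and {B.stride()}")
```

Passing `storage_offset` explicitly keeps the sketch correct for offset views like `base[2:, 1:5]` without relying on the default behaviour of `as_strided`.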
Kiv commented
Thanks Ben! You're right about the issue and I've added test cases for this in the new repo for MLAB2.