
Unofficial complex tensor and scalar support for PyTorch.

Primary language: Python. License: MIT.

Pytorch Complex Tensor

Unofficial complex tensor support for PyTorch.


How it works

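A ComplexTensor is stored as an ordinary real tensor: the first half of its rows (dimension 0) holds the real part and the second half holds the imaginary part. A small set of operations (listed under Supported ops) is overridden to emulate complex arithmetic on the two halves, and gradients flow through them.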
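The storage layout can be illustrated with plain PyTorch. This is a minimal sketch assuming, as the example further down suggests, that the split is along the first dimension; `split_halves` is a hypothetical helper for illustration, not part of the pytorch-complex-tensor API.

```python
import torch

# Hypothetical helper (not the library's API): split a real storage tensor
# whose rows are [real part; imaginary part] into its two halves along dim 0.
def split_halves(storage: torch.Tensor):
    n = storage.size(0)
    if n % 2 != 0:
        raise ValueError("first dimension must be even")
    return storage[: n // 2], storage[n // 2 :]

storage = torch.tensor([[1., 1., 1.],
                        [2., 2., 2.],
                        [3., 3., 3.],
                        [4., 4., 4.]])
real, imag = split_halves(storage)
print(real.shape)   # torch.Size([2, 3])  (the complex shape)
print(imag[:, 0])   # tensor([3., 4.])    (imaginary parts of the first column)
```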

Installation

pip install pytorch-complex-tensor

Example:

Easy import

import torch
from pytorch_complex_tensor import ComplexTensor

Init tensor

# equivalent to:
# np.asarray([[1+3j, 1+3j, 1+3j], [2+4j, 2+4j, 2+4j]]).astype(np.complex64)
C = ComplexTensor([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]])
C.requires_grad = True
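For comparison, the stacked storage corresponding to the NumPy array in the comment above can be built directly. This is a sketch under the same first-dimension assumption; `to_storage` is a hypothetical helper, not something the library provides.

```python
import numpy as np
import torch

# Hypothetical helper: build the stacked [real; imag] storage from a NumPy
# complex array, assuming the split is along the first dimension.
def to_storage(z: np.ndarray) -> torch.Tensor:
    stacked = np.concatenate([z.real, z.imag], axis=0)
    return torch.from_numpy(stacked.astype(np.float32))

z = np.asarray([[1+3j, 1+3j, 1+3j], [2+4j, 2+4j, 2+4j]]).astype(np.complex64)
storage = to_storage(z)
print(storage)  # the same 4 x 3 values passed to ComplexTensor above
```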

Pretty printing

print(C)
# tensor([['(1.0+3.0j)' '(1.0+3.0j)' '(1.0+3.0j)'],
#         ['(2.0+4.0j)' '(2.0+4.0j)' '(2.0+4.0j)']])

abs() returns the complex modulus, sqrt(real^2 + imag^2), as a real tensor

# complex absolute value implementation
print(C.abs())
# tensor([[3.1623, 3.1623, 3.1623],
#         [4.4721, 4.4721, 4.4721]], grad_fn=<SqrtBackward>)
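The modulus |a + bi| = sqrt(a^2 + b^2) can be sketched directly on the two halves. This assumes the stacked layout described above and is not the library's implementation; `complex_abs` is a hypothetical name. Because it uses only differentiable torch ops, the result carries a grad_fn when the input requires grad.

```python
import torch

# Sketch of the complex modulus |a + bi| = sqrt(a^2 + b^2), computed from the
# real and imaginary halves of the stacked storage (split along dim 0 assumed).
def complex_abs(storage: torch.Tensor) -> torch.Tensor:
    real, imag = storage.chunk(2, dim=0)
    return torch.sqrt(real ** 2 + imag ** 2)

storage = torch.tensor([[1., 1., 1.],
                        [2., 2., 2.],
                        [3., 3., 3.],
                        [4., 4., 4.]], requires_grad=True)
out = complex_abs(storage)
print(out)  # rows of sqrt(10) ~ 3.1623 and sqrt(20) ~ 4.4721
```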

size() reports the complex shape: the 4 x 3 storage above, with the real part on top and the imaginary part below, is a 2 x 3 complex tensor

print(C.size())
# torch.Size([2, 3])

Matrix multiplication works with both complex and real tensors

# show matrix multiply with real tensor
# also works with complex tensor
x = torch.Tensor([[3, 3], [4, 4], [2, 2]])
xy = C.mm(x)
print(xy)
# tensor([['(9.0+27.0j)' '(9.0+27.0j)'],
#         ['(18.0+36.0j)' '(18.0+36.0j)']])
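Complex matrix multiplication follows from (A + iB)(C + iD) = (AC - BD) + i(AD + BC); a real right-hand side is the special case D = 0. Below is a sketch on stacked storage with a hypothetical `complex_mm` helper; it reproduces the numbers above but is not claimed to be how the library implements `mm`.

```python
from typing import Optional

import torch

# Sketch of complex matrix multiplication on stacked [real; imag] storage:
# (A + iB)(C + iD) = (AC - BD) + i(AD + BC).
# A real right-hand side is the special case D = 0.
def complex_mm(left: torch.Tensor, right_real: torch.Tensor,
               right_imag: Optional[torch.Tensor] = None) -> torch.Tensor:
    a, b = left.chunk(2, dim=0)
    c = right_real
    d = torch.zeros_like(c) if right_imag is None else right_imag
    real = a @ c - b @ d
    imag = a @ d + b @ c
    return torch.cat([real, imag], dim=0)

storage = torch.tensor([[1., 1., 1.],
                        [2., 2., 2.],
                        [3., 3., 3.],
                        [4., 4., 4.]])
x = torch.tensor([[3., 3.], [4., 4.], [2., 2.]])
out = complex_mm(storage, x)
print(out)  # real half [[9, 9], [18, 18]], imaginary half [[27, 27], [36, 36]]
```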

Reduce operations return a ComplexScalar

xy = xy.sum()

# this is now a complex scalar (thin wrapper with .real, .imag)
print(type(xy))
# <class 'pytorch_complex_tensor.complex_scalar.ComplexScalar'>

print(xy)
# (54+126j)
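A sum reduction amounts to summing each half separately. This sketch wraps the two totals in Python's built-in `complex` rather than the library's ComplexScalar, and the hypothetical `complex_sum` helper uses `.item()`, which detaches from the autograd graph, so unlike ComplexScalar its result cannot be backpropagated through.

```python
import torch

# Sketch of a sum reduction: sum the real and imaginary halves separately and
# wrap the totals in Python's built-in complex type (the library returns its
# own ComplexScalar wrapper instead; .item() drops autograd history here).
def complex_sum(storage: torch.Tensor) -> complex:
    real, imag = storage.chunk(2, dim=0)
    return complex(real.sum().item(), imag.sum().item())

# Stacked storage of the 2 x 2 product computed above.
xy_storage = torch.tensor([[9., 9.], [18., 18.], [27., 27.], [36., 36.]])
print(complex_sum(xy_storage))  # (54+126j)
```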

A ComplexScalar can be used as the target of backward(); the gradient is taken with respect to its real part.

# calculate dxy / dC
# for complex scalars, grad is wrt the real part
xy.backward()
print(C.grad)
# tensor([['(6.0-0.0j)' '(8.0-0.0j)' '(4.0-0.0j)'],
#         ['(6.0-0.0j)' '(8.0-0.0j)' '(4.0-0.0j)']])
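The gradient above can be checked with plain autograd on the stacked storage. This sketch assumes, as stated above, that the gradient is taken of the real part of the sum. Because x is real, that real part depends only on the real half of C, so each real entry's gradient is the corresponding row sum of x (6, 8, 4) and the imaginary half receives zero, matching the 0j parts printed above.

```python
import torch

# Stacked [real; imag] storage for C, with gradients enabled on all of it.
storage = torch.tensor([[1., 1., 1.],
                        [2., 2., 2.],
                        [3., 3., 3.],
                        [4., 4., 4.]], requires_grad=True)
x = torch.tensor([[3., 3.], [4., 4.], [2., 2.]])

# Re(sum(C @ x)) for a real x uses only the real half of the storage.
real = storage[: storage.size(0) // 2]
real_part_of_sum = (real @ x).sum()
real_part_of_sum.backward()

print(storage.grad)
# tensor([[6., 8., 4.],
#         [6., 8., 4.],
#         [0., 0., 0.],
#         [0., 0., 0.]])
```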

Indexing and slicing work as for regular tensors, including negative indices, slices and Ellipsis:

print(C[-1])
print(C[0, 0:-2, ...])
print(C[0, ..., 0])
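One way to map an index on the complex shape onto the storage is to apply the same key to each half. This is a sketch, not the library's indexing code; `index_halves` is hypothetical and assumes the key addresses the 2 x 3 complex shape rather than the 4 x 3 storage.

```python
import torch

# Hypothetical helper: apply the same index to the real and imaginary halves.
# Assumes the key addresses the complex shape, not the underlying storage.
def index_halves(storage: torch.Tensor, key):
    n = storage.size(0) // 2
    return storage[:n][key], storage[n:][key]

storage = torch.tensor([[1., 1., 1.],
                        [2., 2., 2.],
                        [3., 3., 3.],
                        [4., 4., 4.]])

last_re, last_im = index_halves(storage, -1)               # like C[-1]
print(last_re, last_im)  # tensor([2., 2., 2.]) tensor([4., 4., 4.])

slice_re, slice_im = index_halves(storage, (0, slice(0, -2), ...))  # like C[0, 0:-2, ...]
corner_re, corner_im = index_halves(storage, (0, ..., 0))           # like C[0, ..., 0]
```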

Supported ops:

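| Operation | Complex tensor | Real tensor | Complex scalar | Real scalar |
| --- | --- | --- | --- | --- |
| addition | Y | Y | Y | Y |
| subtraction | Y | Y | Y | Y |
| multiply | Y | Y | Y | Y |
| mm | Y | Y | Y | Y |
| abs | Y | - | - | - |
| t (transpose) | Y | - | - | - |
| grads | Y | Y | Y | Y |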
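The elementwise rows of the table, plus `t`, can be sketched on stacked storage as follows. The helper names are hypothetical and the split along dim 0 is assumed; `t` here is a plain transpose of each half with no conjugation, which may or may not match the library's behavior.

```python
import torch

# Hypothetical helpers on stacked [real; imag] storage (split along dim 0).
def split(s: torch.Tensor):
    return s.chunk(2, dim=0)

def join(real: torch.Tensor, imag: torch.Tensor) -> torch.Tensor:
    return torch.cat([real, imag], dim=0)

def add(x, y):
    (a, b), (c, d) = split(x), split(y)
    return join(a + c, b + d)

def sub(x, y):
    (a, b), (c, d) = split(x), split(y)
    return join(a - c, b - d)

def mul(x, y):
    # Elementwise (a + ib)(c + id) = (ac - bd) + i(ad + bc)
    (a, b), (c, d) = split(x), split(y)
    return join(a * c - b * d, a * d + b * c)

def t(x):
    # Plain transpose of each half; no complex conjugation is applied.
    a, b = split(x)
    return join(a.t(), b.t())

x = join(torch.tensor([[1.]]), torch.tensor([[2.]]))  # 1 + 2i
y = join(torch.tensor([[3.]]), torch.tensor([[4.]]))  # 3 + 4i
print(mul(x, y))  # real part -5, imaginary part 10: (1+2i)(3+4i) = -5+10i
```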