X10 Tensor Performance for LSTMs
tanmayb123 opened this issue · 4 comments
I'm running an experiment to compare the performance of LSTMs on Swift for TensorFlow and TensorFlow in Python. I'm using the following (badly written) code:
import time
import tensorflow as tf
@tf.function(experimental_compile=True)
def lstm(ih, hh, b, ts_input, ts_hidden, ts_cell, hiddensize):
    """One LSTM step for a whole batch, XLA-compiled.

    Args:
        ih: input-to-hidden weights, shape (input_size, 4 * hiddensize).
        hh: hidden-to-hidden weights, shape (hiddensize, 4 * hiddensize).
        b: bias, shape (4 * hiddensize,).
        ts_input: batch input at this timestep, shape (batch, input_size).
        ts_hidden: previous hidden state, shape (batch, hiddensize).
        ts_cell: previous cell state, shape (batch, hiddensize).
        hiddensize: width of one gate slab.

    Returns:
        Tuple (new_hidden, new_cell), each shape (batch, hiddensize).
    """
    # Single fused pre-activation for all four gates: one matmul per operand.
    preact = tf.linalg.matmul(ts_input, ih) + tf.linalg.matmul(ts_hidden, hh) + b
    input_gate = tf.math.sigmoid(preact[:, 0:hiddensize])
    forget_gate = tf.math.sigmoid(preact[:, hiddensize:hiddensize * 2])
    # NOTE(review): the candidate uses sigmoid, not the conventional tanh —
    # preserved as-is from the original benchmark code.
    candidate = tf.math.sigmoid(preact[:, hiddensize * 2:hiddensize * 3])
    output_gate = tf.math.sigmoid(preact[:, hiddensize * 3:])
    new_cell = forget_gate * ts_cell + input_gate * candidate
    new_hidden = output_gate * tf.math.tanh(new_cell)
    return (new_hidden, new_cell)
def run_prediction(ih, hh, b, hiddensize, inputs):
    """Unrolls the LSTM over every timestep of `inputs`.

    Args:
        ih, hh, b: LSTM parameters (see `lstm`).
        hiddensize: gate width.
        inputs: tensor of shape (timesteps, batch, input_size).

    Returns:
        List of hidden states, length timesteps + 1 (the initial all-zeros
        hidden state is included at position 0).
    """
    hidden = tf.zeros((inputs.shape[1], hiddensize))
    cell = tf.zeros((inputs.shape[1], hiddensize))
    hiddens = [hidden]
    for t in range(0, inputs.shape[0]):
        # Index with the Python int directly; wrapping it in tf.constant on
        # every iteration (as before) only created a throwaway tensor per step.
        hidden, cell = lstm(ih, hh, b, inputs[t], hidden, cell, hiddensize)
        hiddens.append(hidden)
    return hiddens
# Randomly initialized LSTM parameters: input width 26, hidden width 256,
# with all four gates fused along the last axis (hence the * 4).
ih = tf.random.uniform((26, 256 * 4))
hh = tf.random.uniform((256, 256 * 4))
b = tf.random.uniform((256 * 4,))
hiddensize = tf.constant(256)
# 380 timesteps, batch of 128, feature size 26.
inputs = tf.random.uniform((380, 128, 26))


def run():
    """Times one full forward pass and prints the final hidden-state shape."""
    # perf_counter is the correct monotonic clock for measuring intervals;
    # time.time() can jump if the system wall clock is adjusted.
    s = time.perf_counter()
    print(run_prediction(ih, hh, b, hiddensize, inputs)[-1].shape)
    e = time.perf_counter()
    print(e - s)


run()  # first call includes tracing + XLA compilation
run()  # second call measures steady-state performance
import Foundation
import TensorFlow
let device: Device = .defaultXLA
/// The pair of states produced by one LSTM step.
/// Conforms to `Differentiable` so it can flow through `@differentiable` code.
struct LSTMOutput: Differentiable {
    // New hidden state, h_t.
    var hidden: Tensor<Float>
    // New cell state, c_t.
    var cell: Tensor<Float>
}
/// One LSTM step for a whole batch, differentiable w.r.t. the parameters.
///
/// - Parameters:
///   - ih: Input-to-hidden weights, shape `[inputSize, 4 * hiddenSize]`.
///   - hh: Hidden-to-hidden weights, shape `[hiddenSize, 4 * hiddenSize]`.
///   - b: Bias, shape `[4 * hiddenSize]`.
///   - tsInput: Batch input for this timestep, shape `[batch, inputSize]`.
///   - tsHidden: Previous hidden state, shape `[batch, hiddenSize]`.
///   - tsCell: Previous cell state, shape `[batch, hiddenSize]`.
///   - hiddenSize: Width of a single gate slab.
/// - Returns: The new hidden and cell states.
@differentiable(wrt: (ih, hh, b))
func lstm(ih: Tensor<Float>,
          hh: Tensor<Float>,
          b: Tensor<Float>,
          tsInput: Tensor<Float>,
          tsHidden: Tensor<Float>,
          tsCell: Tensor<Float>,
          hiddenSize: Int) -> LSTMOutput {
    // Single fused pre-activation covering all four gates.
    let preActivation = matmul(tsInput, ih) + matmul(tsHidden, hh) + b
    let batch = preActivation.shape[0]
    let inputGate = sigmoid(preActivation.slice(lowerBounds: [0, 0], upperBounds: [batch, hiddenSize]))
    let forgetGate = sigmoid(preActivation.slice(lowerBounds: [0, hiddenSize], upperBounds: [batch, hiddenSize * 2]))
    // NOTE(review): the candidate uses sigmoid, not the conventional tanh —
    // preserved as-is from the original benchmark code.
    let candidate = sigmoid(preActivation.slice(lowerBounds: [0, hiddenSize * 2], upperBounds: [batch, hiddenSize * 3]))
    let outputGate = sigmoid(preActivation.slice(lowerBounds: [0, hiddenSize * 3], upperBounds: [batch, hiddenSize * 4]))
    let newCell = forgetGate * tsCell + inputGate * candidate
    let newHidden = outputGate * tanh(newCell)
    return .init(hidden: newHidden, cell: newCell)
}
/// Unrolls the LSTM over every timestep in `inputs`.
///
/// - Parameters:
///   - ih, hh, b: LSTM parameters (see `lstm`).
///   - hiddenSize: Gate width.
///   - inputs: One `[batch, inputSize]` tensor per timestep; must be non-empty.
/// - Returns: Hidden states, count `inputs.count + 1` (the initial all-zeros
///   hidden state is included at position 0).
@differentiable(wrt: (ih, hh, b))
func runPrediction(ih: Tensor<Float>,
                   hh: Tensor<Float>,
                   b: Tensor<Float>,
                   hiddenSize: Int,
                   inputs: [Tensor<Float>]) -> [Tensor<Float>] {
    // Fail with a clear message instead of an opaque index-out-of-bounds trap.
    precondition(!inputs.isEmpty, "runPrediction requires at least one timestep")
    let batchSize = inputs[0].shape[0]
    var hidden = Tensor<Float>(zeros: [batchSize, hiddenSize], on: device)
    var cell = Tensor<Float>(zeros: [batchSize, hiddenSize], on: device)
    var hiddens: [Tensor<Float>] = [hidden]
    // The loop bound is data-independent, so exclude it from differentiation.
    for i in 0..<withoutDerivative(at: inputs.count) {
        let result = lstm(ih: ih, hh: hh, b: b, tsInput: inputs[i], tsHidden: hidden, tsCell: cell, hiddenSize: hiddenSize)
        hidden = result.hidden
        cell = result.cell
        hiddens.append(hidden)
    }
    return hiddens
}
// Randomly initialized LSTM parameters: input width 26, hidden width 256,
// all four gates fused along the last axis (hence the * 4).
let ih = Tensor<Float>(randomUniform: [26, 256*4], on: device)
let hh = Tensor<Float>(randomUniform: [256, 256*4], on: device)
let b = Tensor<Float>(randomUniform: [256*4], on: device)
let hiddenSize = 256
// 380 timesteps, each a [batch = 128, features = 26] tensor.
let inputs: [Tensor<Float>] = (1...380).map { _ in Tensor(randomUniform: [128, 26], on: device) }

/// Times one full forward pass and prints the final hidden state's shape.
func run() {
    let startedAt = Date().timeIntervalSince1970
    let finalHidden = runPrediction(ih: ih, hh: hh, b: b, hiddenSize: hiddenSize, inputs: inputs).last!
    print(finalHidden.shape)
    let finishedAt = Date().timeIntervalSince1970
    print(finishedAt - startedAt)
}

run()
run()
Python gives me the following output:
2020-05-09 23:32:56.595915: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-05-09 23:32:56.606559: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f8598660a50 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-05-09 23:32:56.606573: I tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): Host, Default Version
2020-05-09 23:32:56.857759: I tensorflow/compiler/jit/xla_compilation_cache.cc:242] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
(128, 256)
0.544409990310669
(128, 256)
0.27017807960510254
Swift (with `.defaultTFEager`) gives me:
[128, 256]
0.40241098403930664
[128, 256]
0.37356019020080566
Swift (with `.defaultXLA`) gives me:
2020-05-09 23:33:09.025920: I tensorflow/compiler/xla/xla_client/xrt_local_service.cc:54] Peer localservice 1 {localhost:35392}
2020-05-09 23:33:09.026451: I tensorflow/core/platform/cpu_feature_guard.cc:143] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.2 AVX AVX2 FMA
2020-05-09 23:33:09.075110: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x11b126e30 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-05-09 23:33:09.075126: I tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): Host, Default Version
2020-05-09 23:33:09.077361: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job localservice -> {0 -> localhost:35392}
2020-05-09 23:33:09.077748: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:390] Started server with target: grpc://localhost:35392
2020-05-09 23:33:09.084266: W tensorflow/compiler/jit/xla_device.cc:398] XLA_GPU and XLA_CPU devices are deprecated and will be removed in subsequent releases. Instead, use either @tf.function(experimental_compile=True) for must-compile semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 for auto-clustering best-effort compilation.
[128, 256]
1.3604919910430908
[128, 256]
1.3101921081542969
I was wondering, what causes so much extra overhead on the XLA device in Swift? Is this an issue with the way I've written my code, or is it an issue with how the tensors are implemented? If so, is it a known issue and are there plans to fix it soon?
If it's something that will be fixed soon, then I can use S4TF for training the LSTMs in my next project.
This experiment was run on a MacBook Pro; the file was compiled and run using `swiftc -O main.swift && ./main`.
Thanks!
(I had originally put this issue here tensorflow/swift#461 but I thought this repo would be more appropriate)
@tanmayb123 Have a look at https://github.com/tensorflow/swift-apis/blob/master/Sources/x10/swift_bindings/doc/TROUBLESHOOTING.md, it explains the possible issues and the process of investigating them in detail. `PrintX10Metrics` is especially interesting; I suspect we're seeing either recompilation or calls into operations that break the graph.
In any case, I'll have a look next week since you provided a repro. I'm optimistic we can make it fast. Thanks!
Thanks Alex! I put `printX10Metrics()` after each `run()`, and here's the output on my computer:
2020-05-10 14:04:58.734379: I tensorflow/compiler/xla/xla_client/xrt_local_service.cc:54] Peer localservice 1 {localhost:36954}
2020-05-10 14:04:58.735080: I tensorflow/core/platform/cpu_feature_guard.cc:143] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.2 AVX AVX2 FMA
2020-05-10 14:04:58.786298: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x11d176470 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-05-10 14:04:58.786313: I tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): Host, Default Version
2020-05-10 14:04:58.788425: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:301] Initialize GrpcChannelCache for job localservice -> {0 -> localhost:36954}
2020-05-10 14:04:58.788727: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:390] Started server with target: grpc://localhost:36954
2020-05-10 14:04:58.795367: W tensorflow/compiler/jit/xla_device.cc:398] XLA_GPU and XLA_CPU devices are deprecated and will be removed in subsequent releases. Instead, use either @tf.function(experimental_compile=True) for must-compile semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 for auto-clustering best-effort compilation.
[128, 256]
1.3609950542449951
2020-05-10 14:05:00.252778: I swift_bindings/xla_tensor_wrapper.cc:790] Metrics:
Metric: CompileTime
TotalSamples: 1
Accumulator: 017ms749.888us
Percentiles: 1%=017ms749.888us; 5%=017ms749.888us; 10%=017ms749.888us; 20%=017ms749.888us; 50%=017ms749.888us; 80%=017ms749.888us; 90%=017ms749.888us; 95%=017ms749.888us; 99%=017ms749.888us
Metric: DeviceLockWait
TotalSamples: 1520
Accumulator: 001ms184.225us
ValueRate: 910.637us / second
Rate: 1168.11 / second
Percentiles: 1%=000.689us; 5%=000.724us; 10%=000.738us; 20%=000.751us; 50%=000.776us; 80%=000.802us; 90%=000.819us; 95%=000.826us; 99%=000.860us
Metric: ExecuteTime
TotalSamples: 1520
Accumulator: 344ms749.074us
ValueRate: 245ms499.143us / second
Rate: 1168.1 / second
Percentiles: 1%=194.156us; 5%=197.375us; 10%=199.300us; 20%=201.852us; 50%=207.161us; 80%=214.259us; 90%=219.850us; 95%=225.339us; 99%=291.208us
Metric: InboundData
TotalSamples: 1520
Accumulator: 11.88KB
ValueRate: 9.13KB / second
Rate: 1168.11 / second
Percentiles: 1%=8.00B; 5%=8.00B; 10%=8.00B; 20%=8.00B; 50%=8.00B; 80%=8.00B; 90%=8.00B; 95%=8.00B; 99%=8.00B
Metric: IrValueTensorToXlaData
TotalSamples: 3423
Accumulator: 636ms791.698us
ValueRate: 428ms040.837us / second
Rate: 2364.04 / second
Percentiles: 1%=163.701us; 5%=167.896us; 10%=169.601us; 20%=172.005us; 50%=178.523us; 80%=188.350us; 90%=193.929us; 95%=200.828us; 99%=220.842us
Metric: OutboundData
TotalSamples: 3423
Accumulator: 26.74KB
ValueRate: 18.47KB / second
Rate: 2364.05 / second
Percentiles: 1%=8.00B; 5%=8.00B; 10%=8.00B; 20%=8.00B; 50%=8.00B; 80%=8.00B; 90%=8.00B; 95%=8.00B; 99%=8.00B
Metric: ReleaseDataHandlesTime
TotalSamples: 3025
Accumulator: 437ms038.050us
ValueRate: 295ms002.624us / second
Rate: 2355.69 / second
Percentiles: 1%=107.345us; 5%=111.899us; 10%=114.060us; 20%=117.392us; 50%=123.852us; 80%=131.508us; 90%=136.867us; 95%=143.509us; 99%=157.632us
Metric: TensorsGraphSize
TotalSamples: 1520
Accumulator: 4560.00
ValueRate: 3504.33 / second
Rate: 1168.11 / second
Percentiles: 1%=3.00; 5%=3.00; 10%=3.00; 20%=3.00; 50%=3.00; 80%=3.00; 90%=3.00; 95%=3.00; 99%=3.00
Metric: TransferFromServerTime
TotalSamples: 1520
Accumulator: 228ms173.109us
ValueRate: 168ms440.013us / second
Rate: 1168.11 / second
Percentiles: 1%=135.492us; 5%=137.016us; 10%=138.238us; 20%=139.487us; 50%=142.259us; 80%=146.686us; 90%=149.639us; 95%=153.425us; 99%=178.000us
Metric: TransferToServerTime
TotalSamples: 3423
Accumulator: 625ms925.315us
ValueRate: 421ms502.200us / second
Rate: 2364.04 / second
Percentiles: 1%=161.119us; 5%=164.660us; 10%=166.415us; 20%=168.701us; 50%=175.328us; 80%=185.100us; 90%=191.066us; 95%=197.863us; 99%=217.845us
Metric: TransferToServerTransformTime
TotalSamples: 3423
Accumulator: 069ms109.508us
ValueRate: 046ms875.380us / second
Rate: 2364.05 / second
Percentiles: 1%=016.444us; 5%=017.261us; 10%=017.865us; 20%=018.315us; 50%=019.141us; 80%=020.367us; 90%=020.988us; 95%=021.701us; 99%=024.795us
Counter: CachedCompile
Value: 1519
Counter: CreateCompileHandles
Value: 1
Counter: CreateDataHandles
Value: 4943
Counter: CreateXlaTensor
Value: 16376
Counter: DestroyDataHandles
Value: 4560
Counter: DestroyXlaTensor
Value: 15993
Counter: ReleaseDataHandles
Value: 4560
Counter: UncachedCompile
Value: 1
Counter: XRTAllocateFromTensor_Empty
Value: 1
Counter: XrtCompile_Empty
Value: 80
Counter: XrtExecuteChained_Empty
Value: 80
Counter: XrtExecute_Empty
Value: 80
Counter: XrtRead_Empty
Value: 80
Counter: XrtReleaseAllocationHandle_Empty
Value: 80
Counter: XrtReleaseCompileHandle_Empty
Value: 80
Counter: XrtSessionCount
Value: 6
Counter: XrtSubTuple_Empty
Value: 80
[128, 256]
1.3143398761749268
2020-05-10 14:05:01.569849: I swift_bindings/xla_tensor_wrapper.cc:790] Metrics:
Metric: CompileTime
TotalSamples: 1
Accumulator: 017ms749.888us
Percentiles: 1%=017ms749.888us; 5%=017ms749.888us; 10%=017ms749.888us; 20%=017ms749.888us; 50%=017ms749.888us; 80%=017ms749.888us; 90%=017ms749.888us; 95%=017ms749.888us; 99%=017ms749.888us
Metric: DeviceLockWait
TotalSamples: 3040
Accumulator: 002ms391.199us
ValueRate: 925.636us / second
Rate: 1163.08 / second
Percentiles: 1%=000.692us; 5%=000.720us; 10%=000.737us; 20%=000.750us; 50%=000.778us; 80%=000.808us; 90%=000.828us; 95%=000.838us; 99%=000.879us
Metric: ExecuteTime
TotalSamples: 3040
Accumulator: 669ms636.628us
ValueRate: 248ms589.716us / second
Rate: 1163.13 / second
Percentiles: 1%=194.776us; 5%=198.421us; 10%=200.340us; 20%=203.568us; 50%=209.921us; 80%=218.339us; 90%=224.766us; 95%=234.304us; 99%=262.793us
Metric: InboundData
TotalSamples: 3040
Accumulator: 23.75KB
ValueRate: 9.09KB / second
Rate: 1163.13 / second
Percentiles: 1%=8.00B; 5%=8.00B; 10%=8.00B; 20%=8.00B; 50%=8.00B; 80%=8.00B; 90%=8.00B; 95%=8.00B; 99%=8.00B
Metric: IrValueTensorToXlaData
TotalSamples: 6463
Accumulator: 01s202ms439.505us
ValueRate: 432ms677.997us / second
Rate: 2324.28 / second
Percentiles: 1%=166.073us; 5%=169.206us; 10%=171.161us; 20%=173.942us; 50%=181.884us; 80%=192.092us; 90%=202.007us; 95%=214.134us; 99%=276.403us
Metric: OutboundData
TotalSamples: 6463
Accumulator: 50.49KB
ValueRate: 18.16KB / second
Rate: 2324.29 / second
Percentiles: 1%=8.00B; 5%=8.00B; 10%=8.00B; 20%=8.00B; 50%=8.00B; 80%=8.00B; 90%=8.00B; 95%=8.00B; 99%=8.00B
Metric: ReleaseDataHandlesTime
TotalSamples: 6052
Accumulator: 830ms432.265us
ValueRate: 299ms708.863us / second
Rate: 2307.4 / second
Percentiles: 1%=109.134us; 5%=114.307us; 10%=116.776us; 20%=119.445us; 50%=127.032us; 80%=136.125us; 90%=144.214us; 95%=151.779us; 99%=186.353us
Metric: TensorsGraphSize
TotalSamples: 3040
Accumulator: 9120.00
ValueRate: 3489.23 / second
Rate: 1163.08 / second
Percentiles: 1%=3.00; 5%=3.00; 10%=3.00; 20%=3.00; 50%=3.00; 80%=3.00; 90%=3.00; 95%=3.00; 99%=3.00
Metric: TransferFromServerTime
TotalSamples: 3040
Accumulator: 450ms452.323us
ValueRate: 170ms100.630us / second
Rate: 1163.13 / second
Percentiles: 1%=135.276us; 5%=137.325us; 10%=138.467us; 20%=140.148us; 50%=143.362us; 80%=147.595us; 90%=152.353us; 95%=160.734us; 99%=223.151us
Metric: TransferToServerTime
TotalSamples: 6463
Accumulator: 01s182ms789.125us
ValueRate: 424ms184.007us / second
Rate: 2324.28 / second
Percentiles: 1%=162.921us; 5%=166.040us; 10%=167.950us; 20%=170.700us; 50%=178.610us; 80%=188.897us; 90%=199.042us; 95%=210.882us; 99%=272.675us
Metric: TransferToServerTransformTime
TotalSamples: 6463
Accumulator: 130ms781.995us
ValueRate: 046ms462.256us / second
Rate: 2324.29 / second
Percentiles: 1%=016.422us; 5%=017.588us; 10%=018.087us; 20%=018.516us; 50%=019.299us; 80%=020.727us; 90%=022.177us; 95%=023.871us; 99%=033.486us
Counter: CachedCompile
Value: 3039
Counter: CreateCompileHandles
Value: 1
Counter: CreateDataHandles
Value: 9503
Counter: CreateXlaTensor
Value: 28922
Counter: DestroyDataHandles
Value: 9120
Counter: DestroyXlaTensor
Value: 28539
Counter: ReleaseDataHandles
Value: 9120
Counter: UncachedCompile
Value: 1
Counter: XRTAllocateFromTensor_Empty
Value: 1
Counter: XrtCompile_Empty
Value: 80
Counter: XrtExecuteChained_Empty
Value: 80
Counter: XrtExecute_Empty
Value: 80
Counter: XrtRead_Empty
Value: 80
Counter: XrtReleaseAllocationHandle_Empty
Value: 80
Counter: XrtReleaseCompileHandle_Empty
Value: 80
Counter: XrtSessionCount
Value: 6
Counter: XrtSubTuple_Empty
Value: 80
A sequence of fixes in master, starting with #982, have fixed this. The nightly toolchain should work a lot better with LSTMs now.
On my Linux machine, the original example is 35x (!) faster with x10 than eager. However, note that the original example doesn't look at the actual contents of the tensor, just the shape - which means x10 never needs to execute anything, which explains the enormous gap.
A more realistic scenario, which prints the sum of the prediction instead of the shape, still shows a 1.7x speedup with x10 vs eager, on my CPU. On my GPU both backends are about as fast.
Note that the first `run()` call will take a very long time (~3 minutes) because the XLA compilation itself is slow for this example, so I'm only reporting the second run. Subsequent runs will be as fast as the second run.
Last but not least, the current focus of x10 is Linux and TPUs / GPUs, so ymmv on macOS with CPU.
I'm going to close this issue for now as it appears to be solved.