StringLookup layer never releases memory on load
System information.
- Have I written custom code (as opposed to using a stock example script): yes, shown below
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 22.04 - docker container
FROM python:3.10
- TensorFlow installed from (source or binary): 2.12.0
- TensorFlow version (use command below): 2.12.0
- Python version: 3.10.12
- Exact command to reproduce: see code below
Describe the problem.
1. Define a simple network with a StringLookup layer of non-trivial vocabulary size.
2. Reload the network in a for loop.
3. Memory grows without bound.
Describe the current behavior.
Memory usage increases on every load.
Describe the expected behavior.
Memory usage stays flat across loads.
Standalone code to reproduce the issue.
import time
import tensorflow as tf
import psutil

def model_fn():
    X = tf.keras.Input(shape=(1,), dtype=tf.string)
    lookup = tf.keras.layers.StringLookup(
        vocabulary=tf.constant([str(x) for x in range(100_000)])
    )(X)
    Y = tf.math.reduce_sum(lookup, axis=1)
    return tf.keras.Model(inputs=[X], outputs=[Y])

model = model_fn()
model.save("/tmp/test-model")

loaded_model = None
process = psutil.Process()

def get_current_mem():
    # Resident set size in megabytes
    return process.memory_info().rss / 1e6

def load():
    global loaded_model
    loaded_model = tf.saved_model.load("/tmp/test-model")

print('==========================================================')
print("starting process...")
for i in range(100_000):
    start_mem = get_current_mem()
    start = time.time()
    print(f"i={i} loading...", end='')
    load()
    curr_mem = get_current_mem()
    end = time.time()
    print(f"done (mem_usage={curr_mem - start_mem}mb took={int(end - start)}s)")
    time.sleep(0.25)
Here is an example of running the code above. Notice that memory usage goes up on every load; as of writing this, the process is using 9.6 GB of memory.
Finally, if I swap the StringLookup layer for a Hashing layer instead, memory behaves fine:
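A minimal sketch of that swap (the helper name `model_fn_hashing` and the `num_bins` value are my own; the key point is that Hashing is stateless, so there is no vocabulary table resource to leak on reload):

```python
import tensorflow as tf

def model_fn_hashing(num_bins=100_000):
    # Same toy network as above, but the stateless Hashing layer replaces
    # StringLookup. Hashing keeps no vocabulary table, so loading the
    # SavedModel creates no TrackableResource that could be leaked.
    X = tf.keras.Input(shape=(1,), dtype=tf.string)
    hashed = tf.keras.layers.Hashing(num_bins=num_bins)(X)
    Y = tf.math.reduce_sum(hashed, axis=1)
    return tf.keras.Model(inputs=[X], outputs=[Y])
```

This is only viable if hash collisions are acceptable for your use case, since Hashing does not give you an exact string-to-index mapping.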
One potential solution I've found is to manage the TrackableResources directly myself. Consider the following code:
prev_model = loaded_model
loaded_model = tf.saved_model.load("/tmp/test-model")
if prev_model:
    lookups = filter(None, [
        layer._trackable_children().get('lookup_table') if hasattr(layer, '_trackable_children') else None
        for layer in prev_model._trackable_children().values()
    ])
    for lookup in lookups:
        tf.raw_ops.DestroyResourceOp(resource=lookup.resource_handle)
del prev_model
Doing that after reloading appears to clear the resources as expected and not cause memory to grow indefinitely.
For any future readers, here is a v2 that walks the whole trackable tree instead of only the top-level layers:
def wipeit(model):
    def _give_handles(m):
        # Yield the resource handle of every TrackableResource in the tree.
        if hasattr(m, 'resource_handle'):
            yield m.resource_handle
        children = {}
        try:
            children = m._trackable_children()
        except Exception:
            pass
        for child in children.values():
            yield from _give_handles(child)
    for handle in _give_handles(model):
        tf.raw_ops.DestroyResourceOp(resource=handle)

...

# during model load, after it is loaded
prev_model = loaded_model
loaded_model = tf.saved_model.load(...)
if prev_model:
    wipeit(prev_model)
    del prev_model
@cjmcgraw thanks for filing the issue. This issue seems to be related to TF itself rather than Keras. The behavior can be reproduced without any Keras APIs: just put a lookup table in a model, then save and load it -- the behavior is the same.
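For example, a minimal Keras-free repro along those lines (the `TableModule` name and `/tmp/test-table` path are illustrative, not from the original report):

```python
import tensorflow as tf

class TableModule(tf.Module):
    """Bare tf.Module holding the same kind of lookup table
    (a TrackableResource) that backs StringLookup's vocabulary."""

    def __init__(self, n=100_000):
        keys = tf.constant([str(x) for x in range(n)])
        values = tf.range(n, dtype=tf.int64)
        self.table = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(keys, values),
            default_value=-1,
        )

    @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
    def lookup(self, x):
        return self.table.lookup(x)

tf.saved_model.save(TableModule(), "/tmp/test-table")

# Re-loading in a loop shows the same unbounded RSS growth as the
# Keras repro above -- no Keras APIs involved.
for _ in range(10):
    loaded = tf.saved_model.load("/tmp/test-table")
```

Watching the process RSS (e.g. with psutil, as in the original script) while this loop runs should show the same per-iteration growth.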