keras-team/tf-keras

StringLookup layer never releases memory on load


System information.

  • Have I written custom code (as opposed to a stock example script): yes, shown below
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 22.04 - docker container FROM python:3.10
  • TensorFlow installed from (source or binary): binary (pip, inside the docker image)
  • TensorFlow version (use command below): 2.12.0
  • Python version: 3.10.12
  • Exact command to reproduce: see code below

Describe the problem.

Define a simple network containing a StringLookup layer with a non-trivially sized vocabulary.

Reload the network in a for loop.

Memory usage grows without bound.

Describe the current behavior.
Memory usage increases on every load.

Describe the expected behavior.
Memory usage stays flat across loads.

Standalone code to reproduce the issue.

import time

import tensorflow as tf
import psutil

def model_fn():
    # A trivial model whose only non-trivial state is the StringLookup vocabulary.
    X = tf.keras.Input(shape=(1,), dtype=tf.string)
    lookup = tf.keras.layers.StringLookup(
        vocabulary=tf.constant([str(x) for x in range(100_000)])
    )(X)
    Y = tf.math.reduce_sum(lookup, axis=1)

    return tf.keras.Model(inputs=[X], outputs=[Y])

model = model_fn()
model.save("/tmp/test-model")

loaded_model = None

process = psutil.Process()

def get_current_mem():
    # Resident set size of this process, in MB.
    return process.memory_info().rss / 1e6

def load():
    # Rebinding the global should release the previous model, but it doesn't.
    global loaded_model
    loaded_model = tf.saved_model.load("/tmp/test-model")

print('==========================================================')
print("starting process...")

for i in range(100_000):
    start_mem = get_current_mem()
    start = time.time()

    print(f"i={i} loading...", end='')
    load()

    curr_mem = get_current_mem()
    end = time.time()
    print(f"done (mem_usage={curr_mem - start_mem}mb took={int(end - start)}s)")
    time.sleep(0.25)

Here is an example run of the code above.

Notice that the memory usage goes up on every iteration. As of writing this, the process is using 9.6 GB of memory.

[screenshot: per-iteration output showing mem_usage increasing on every load]

Finally, if I swap the StringLookup for a Hashing layer, memory stays flat:

[screenshot: the same loop with a Hashing layer, showing stable memory usage]
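For reference, a minimal sketch of the Hashing variant used above (the num_bins value is arbitrary, chosen here only for illustration):

import tensorflow as tf

def hashing_model_fn():
    # Hashing keeps no table resource: it hashes each string on the fly
    # instead of storing a 100k-entry vocabulary.
    X = tf.keras.Input(shape=(1,), dtype=tf.string)
    hashed = tf.keras.layers.Hashing(num_bins=100_000)(X)
    Y = tf.math.reduce_sum(hashed, axis=1)
    return tf.keras.Model(inputs=[X], outputs=[Y])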

One potential solution I've found is to manage the TrackableResources directly myself.

Consider the following code:

prev_model = loaded_model
loaded_model = tf.saved_model.load("/tmp/test-model")

if prev_model:
    # Collect the lookup-table resources still held by the previous model.
    lookups = filter(None, [
        layer._trackable_children().get('lookup_table') if hasattr(layer, '_trackable_children') else None
        for layer in prev_model._trackable_children().values()
    ])
    for lookup in lookups:
        # Explicitly free each table resource; dropping the Python reference
        # alone never does.
        tf.raw_ops.DestroyResourceOp(resource=lookup.resource_handle)

    del prev_model

Doing that after reloading appears to clear the resources as expected, and memory no longer grows indefinitely. DestroyResourceOp releases the underlying hash-table resource in the runtime, which is presumably what never happens when the old Python object is simply garbage collected.

For any future readers, here is a V2 that recursively walks every trackable child instead of hard-coding the 'lookup_table' key:

def wipeit(model):
    def _give_handles(m):
        # Yield this trackable's resource handle, if it owns one.
        if hasattr(m, 'resource_handle'):
            yield m.resource_handle

        children = {}
        try:
            children = m._trackable_children()
        except Exception:
            pass

        # Recurse into every trackable child.
        for child in children.values():
            yield from _give_handles(child)

    for handle in _give_handles(model):
        # Explicitly destroy each resource held by the old model.
        tf.raw_ops.DestroyResourceOp(resource=handle)


...

# during model load, after it's loaded
prev_model = loaded_model
loaded_model = tf.saved_model.load(...)

if prev_model:
    wipeit(prev_model)
    del prev_model

@cjmcgraw thanks for filing the issue. This issue seems to be rooted in TF rather than Keras: the behavior can be reproduced without using any Keras APIs. You can just use a lookup table, then save and load the model -- the behavior is the same.
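A minimal sketch of such a Keras-free reproduction, assuming a plain tf.Module wrapping a tf.lookup.StaticHashTable (the LookupModule class and table contents here are illustrative, not the exact code referenced above):

import tensorflow as tf

class LookupModule(tf.Module):
    def __init__(self):
        super().__init__()
        keys = tf.constant([str(x) for x in range(100_000)])
        values = tf.range(100_000, dtype=tf.int64)
        # StaticHashTable owns the same kind of table resource that
        # StringLookup creates under the hood.
        self.table = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(keys, values),
            default_value=-1,
        )

    @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
    def __call__(self, x):
        return self.table.lookup(x)

tf.saved_model.save(LookupModule(), "/tmp/test-table")

# Reloading in a loop shows the same unbounded memory growth.
for _ in range(1_000):
    loaded = tf.saved_model.load("/tmp/test-table")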