firedrakeproject/firedrake

BUG: Checkpointing and referencing of variables

Opened this issue · 0 comments

Describe the bug
Firedrake's Block subclasses hold references to variables through the UFL expressions they store. Because these references keep the variables alive, they can prevent checkpointing from reducing memory usage.

This is the Firedrake-level counterpart of dolfin-adjoint/pyadjoint#169.

Steps to Reproduce

from firedrake import *
from firedrake.adjoint import *
from firedrake.adjoint_utils.blocks.solving import SolveVarFormBlock
from checkpoint_schedules import MultistageCheckpointSchedule

import itertools

# Number of timesteps recorded on the tape.
N = 100

# Minimal problem: a single-cell interval mesh with a P1 Lagrange space.
mesh = UnitIntervalMesh(1)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)

# Enable checkpointing on the working tape for an N-step run.
# NOTE(review): the extra arguments (3, 0) are presumably the numbers of
# snapshots stored in RAM and on disk -- confirm against the
# checkpoint_schedules documentation.
tape = get_working_tape()
tape.enable_checkpointing(MultistageCheckpointSchedule(N, 3, 0))

# Initial value u = 2, created before continue_annotation() is called.
u = Function(space, name="u").interpolate(Constant(2.0))
continue_annotation()
# Record N timesteps. Each step solves an L2 projection of u + u into a
# fresh Function u_, then rebinds u to the new solution and deletes the
# temporary name. After the loop, the script itself therefore holds no
# reference to earlier timesteps' Functions; only the tape can keep them alive.
for _ in tape.timestepper(iter(range(N))):
    u_ = Function(space)
    solve(inner(trial, test) * dx == inner(test, u + u) * dx, u_)
    u = u_
    del u_
pause_annotation()
# Drop the script's last reference to the solution.
del u

# Collect the distinct coefficient ids (dep.count()) that appear in the UFL
# forms stored on each SolveVarFormBlock. Every coefficient found here is
# still reachable through the tape's blocks, regardless of what the
# checkpointing schedule has discarded.
# NOTE(review): `ufl` is not imported explicitly in this script; it is
# presumably made available by `from firedrake import *` -- confirm, or add
# `import ufl` for clarity.
deps = set()
for block in tape._blocks:
    if isinstance(block, SolveVarFormBlock):
        for dep in itertools.chain(ufl.algorithms.extract_coefficients(block.lhs),
                                   ufl.algorithms.extract_coefficients(block.rhs)):
            deps.add(dep.count())

# Number of distinct coefficients still referenced by solve blocks
# (the issue reports 100 for N = 100).
print(f"{len(deps)=}")

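Running this script produces the output below. It shows that the solve blocks on the tape still reference as many distinct coefficients as there are timesteps (N = 100), even though checkpointing is enabled: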

len(deps)=100