BUG: Checkpointing and referencing of variables
Opened this issue · 0 comments
jrmaddison commented
Describe the bug
Firedrake `Block` subclasses reference variables via UFL expressions. This can prevent checkpointing from reducing memory usage.
Firedrake-level version of dolfin-adjoint/pyadjoint#169.
Steps to Reproduce
# Reproducer: with checkpointing enabled, count how many distinct
# coefficients are still referenced by the UFL forms stored on the
# SolveVarFormBlocks of the tape after all Python references are dropped.
from firedrake import *
from firedrake.adjoint import *
from firedrake.adjoint_utils.blocks.solving import SolveVarFormBlock
from checkpoint_schedules import MultistageCheckpointSchedule
import itertools
import ufl.algorithms  # explicit, rather than relying on the wildcard import

N = 100  # number of annotated timesteps (one solve per step)

mesh = UnitIntervalMesh(1)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)

tape = get_working_tape()
# NOTE(review): arguments presumably mean (max_n=N, snapshots_in_ram=3,
# snapshots_on_disk=0) -- confirm against the checkpoint_schedules docs.
tape.enable_checkpointing(MultistageCheckpointSchedule(N, 3, 0))

u = Function(space, name="u").interpolate(Constant(2.0))

continue_annotation()
for _ in tape.timestepper(iter(range(N))):
    u_ = Function(space)
    solve(inner(trial, test) * dx == inner(test, u + u) * dx, u_)
    u = u_
    del u_
pause_annotation()

# Drop the last Python-level reference so that anything still alive is
# being kept alive only by the tape.
del u

deps = set()
for block in tape._blocks:
    if isinstance(block, SolveVarFormBlock):
        # Coefficients appearing in the stored forms are references the
        # block holds independently of the checkpointing machinery.
        for dep in itertools.chain(ufl.algorithms.extract_coefficients(block.lhs),
                                   ufl.algorithms.extract_coefficients(block.rhs)):
            deps.add(dep.count())
print(f"{len(deps)=}")
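leads to the output below, i.e. the `SolveVarFormBlock`s on the tape still reference 100 distinct coefficients through their forms, even though checkpointing is enabled and the Python references have been deleted: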
len(deps)=100