In this graphic we see the basic layout of the system: the neural network controls the cart-pole by applying a force computed from the current state variables, which are obtained by integrating the differential equations. During training, the gradients are propagated through the ODE solver, so that the learning process is informed by the physics of the system.
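The idea of taking gradients through a differential equation can be illustrated without any packages. The following is a toy sketch, not the method used later (which relies on reverse-mode differentiation via ReverseDiffAdjoint): a hand-rolled dual-number type (the hypothetical `Dual`) carries the derivative with respect to a parameter k through explicit Euler steps of dx/dt = -k·x, and the result is compared with the derivative of the closed-form Euler solution.

```julia
# Toy forward-mode AD: a dual number carries a value and its derivative.
# `Dual` and `euler_final` are hypothetical helpers for illustration only.
struct Dual
    val::Float64  # primal value
    der::Float64  # derivative with respect to the parameter k
end

Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:-(a::Dual, b::Dual) = Dual(a.val - b.val, a.der - b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)  # product rule
Base.:*(c::Real, b::Dual) = Dual(c * b.val, c * b.der)

# Explicit Euler integration of dx/dt = -k*x; works for Float64 and Dual alike.
function euler_final(k, x0, dt, nsteps)
    x = x0
    for _ in 1:nsteps
        x = x - dt * (k * x)
    end
    return x
end

k, x0, dt, n = 2.0, 1.0, 0.01, 100
res = euler_final(Dual(k, 1.0), Dual(x0, 0.0), dt, n)  # seed dk/dk = 1

# Closed form of the Euler recursion: x_n = x0 * (1 - k*dt)^n
exact_val = x0 * (1 - k * dt)^n
exact_der = -n * dt * x0 * (1 - k * dt)^(n - 1)

println("x(T) = ", res.val, ", dx(T)/dk = ", res.der)
```

Forward mode needs one pass per parameter; with the 161 network parameters used below, reverse mode, which obtains all gradients in a single backward pass, is the more economical choice.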
The following walks through the code that produces the graphics above. Julia 1.6 was used.
using Flux
using ReachabilityAnalysis
using Plots
using DifferentialEquations
using BSON
using DiffEqFlux
using DiffEqSensitivity
using Optim
Defining the constants: cart mass M, pole mass m, pole length L (all set to 1) and the gravitational acceleration g in m/s²:
M = 1.
m = 1.
L = 1.
g = 9.81
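Defining the differential equations of the force-controlled cart-pole. The state is u = (θ, s, θ̇, ṡ), i.e. pole angle, cart position and their velocities, with θ = 0 for the hanging pole and θ = π for the upright pole: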
function cartpoleC(du,u,p,t)
    # theta, s, dtheta, ds = u[1:4]
    # control force from the network; inputs: sin(theta), cos(theta), dtheta
    F = reC(p)([sin(u[1]), cos(u[1]), u[3]])[1]
    D = M + m - m*cos(u[1])^2   # common denominator, always >= M > 0
    du[1] = u[3]
    du[2] = u[4]
    du[3] = -((M+m)*g*sin(u[1]) + m*L*sin(u[1])*cos(u[1])*u[3]^2 + F*cos(u[1]))/(L*D)
    du[4] = -(m*L*sin(u[1])*u[3]^2 + m*g*sin(u[1])*cos(u[1]) + F)/D
    return nothing
end
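For reference, the right-hand side of cartpoleC in math notation. This is a sketch under the assumption that the pole is a point mass m at distance L from the pivot, with θ measured from the hanging position and D(θ) = M + m - m cos²θ = M + m sin²θ ≥ M > 0, so the denominators never vanish. For L = 1, as set above, these are exactly the expressions in the code:

$$
\ddot{\theta} = -\frac{(M+m)\,g\sin\theta + mL\,\dot{\theta}^2\sin\theta\cos\theta + F\cos\theta}{L\,D(\theta)}, \qquad
\ddot{s} = -\frac{mL\,\dot{\theta}^2\sin\theta + mg\sin\theta\cos\theta + F}{D(\theta)}
$$

At θ = π we have sin θ = 0, so with F = 0 the upright position is an equilibrium, but an unstable one, which is what the controller has to stabilize.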
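In cartpoleC, reC(p) rebuilds the neural network from the flat parameter vector p returned by Flux.destructure. The network maps three inputs (sin θ, cos θ, θ̇) through a hidden layer of 32 tanh units to a single linear output, the force F; the cart position and velocity are not fed to the network. It is defined as: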
nnC = Chain(Dense(3,32, tanh), Dense(32,1) )
p, reC = Flux.destructure(nnC)
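To make the flat-vector idea concrete, here is a package-free sketch of what destructuring means for this 3 → 32 (tanh) → 1 network. `make_mlp` and `rebuild` are hypothetical names, not Flux's implementation; the sketch only assumes Flux's convention that a Dense layer's weight matrix has size out×in. It shows that the network has 32·3 + 32 + 1·32 + 1 = 161 trainable parameters.

```julia
# Hypothetical stand-in for Flux.destructure on a 3 -> 32 (tanh) -> 1 network:
# all weights and biases are packed into one flat vector, and a closure
# rebuilds a callable network from any vector of the same length.
function make_mlp(nin, nhidden, nout)
    shapes = [(nhidden, nin), (nhidden,), (nout, nhidden), (nout,)]  # W1, b1, W2, b2
    sizes = [prod(s) for s in shapes]
    theta = 0.1 .* randn(sum(sizes))            # flat parameter vector
    function rebuild(p)
        @assert length(p) == sum(sizes)
        offs = cumsum([0; sizes])                 # start offsets of each block
        parts = [reshape(p[offs[i]+1:offs[i+1]], shapes[i]) for i in 1:4]
        W1, b1, W2, b2 = parts
        return x -> W2 * tanh.(W1 * x .+ b1) .+ b2  # forward pass
    end
    return theta, rebuild
end

theta, rebuild = make_mlp(3, 32, 1)
net = rebuild(theta)
force = net([sin(0.0), cos(0.0), 0.0])[1]   # force for the hanging pole at rest
```

The optimizer only ever sees `theta`; inside the ODE right-hand side the network is rebuilt from whatever vector the optimizer is currently trying.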
Defining the initial condition for training (cart at zero position, pole facing down, velocities zero):
x0 = zeros(4)
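Defining a function that solves the ODE for a given network parameter vector theta (not to be confused with the angle θ) and returns the solution, saved every Δt, as a 4×N array. With sensealg=ReverseDiffAdjoint(), gradients are computed by reverse-mode automatic differentiation through the solver steps (via ReverseDiff.jl). t_span and Δt are defined further below; they only need to exist when predictC is first called:
predictC(theta) = Array(DifferentialEquations.solve(ODEProblem(cartpoleC, x0, t_span), Tsit5(),
                                                   p = theta, saveat = Δt,
                                                   sensealg = ReverseDiffAdjoint()))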
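To show the shape of what predictC returns, here is a package-free sketch that integrates the same point-mass cart-pole equations (identical to the code for L = 1) for a given controller and stores the state at t = 0, Δt, 2Δt, …, like `Array(sol)` with `saveat = Δt`. It is a simplification: Tsit5 is an adaptive fifth-order method, whereas this sketch uses fixed-step classical RK4, and `cartpole_rhs` and `simulate` are hypothetical names.

```julia
# Cart-pole right-hand side, state u = (θ, s, dθ, ds), force from `controller`.
function cartpole_rhs(u, controller; M=1.0, m=1.0, L=1.0, g=9.81)
    θ, s, dθ, ds = u
    F = controller(u)
    D = M + m - m * cos(θ)^2
    ddθ = -((M + m) * g * sin(θ) + m * L * sin(θ) * cos(θ) * dθ^2 + F * cos(θ)) / (L * D)
    dds = -(m * L * sin(θ) * dθ^2 + m * g * sin(θ) * cos(θ) + F) / D
    return [dθ, ds, ddθ, dds]
end

# Fixed-step RK4; column i of the result is the state at time ts[i].
function simulate(controller, x0, tspan, Δt)
    ts = tspan[1]:Δt:tspan[2]
    X = zeros(length(x0), length(ts))
    X[:, 1] = x0
    for i in 1:length(ts)-1
        u = X[:, i]
        k1 = cartpole_rhs(u, controller)
        k2 = cartpole_rhs(u .+ (Δt / 2) .* k1, controller)
        k3 = cartpole_rhs(u .+ (Δt / 2) .* k2, controller)
        k4 = cartpole_rhs(u .+ Δt .* k3, controller)
        X[:, i+1] = u .+ (Δt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
    end
    return X
end

# Zero force from the hanging rest position: the pole stays at the equilibrium.
X = simulate(u -> 0.0, zeros(4), (0.0, 1.0), 0.01)   # size(X) == (4, 101)
```

With t_span = (0, 1) and Δt = 0.01 there are 101 saved time points, so `pr[:, end]` in the loss below is the state at t = 1 s.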
Defining the loss function, which penalizes the deviation of the final pole angle from the upright position (θ = π) and of the final cart position from zero:
function lossC(theta)
    pr = predictC(theta)
    # final angle theta and final cart position s
    loss = 10*abs2(π - pr[1,end]) + 0.6*abs2(pr[2,end])
    return loss
end
Defining the training time span (in seconds) and the interval Δt at which the solution is saved:
t_span = (0., 1.)
Δt = 0.01
Defining a callback that displays the current loss; returning false lets the training continue:
callback = function (p, l)
display(l)
return false
end
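The callback contract can be illustrated with a plain gradient-descent loop. This is a hypothetical sketch (`gd_train` is not part of DiffEqFlux); it only mimics the convention that the optimizer calls `cb(p, l)` after each step and stops early when the callback returns `true`.

```julia
# Hypothetical plain gradient-descent loop that calls `cb(p, l)` after each step;
# returning `true` from the callback halts the loop early.
function gd_train(loss, grad, p; η=0.1, maxiters=100, cb=(p, l) -> false)
    p = copy(p)
    for _ in 1:maxiters
        p .-= η .* grad(p)
        cb(p, loss(p)) && break   # stop early if the callback says so
    end
    return p
end

loss(p) = sum(abs2, p .- 3.0)     # toy loss with minimum at p = [3, 3]
grad(p) = 2 .* (p .- 3.0)

history = Float64[]
stop_when_small = function (p, l)
    push!(history, l)               # record the loss, like display(l) above
    return l < 1e-6                 # halt once the loss is tiny
end

p_opt = gd_train(loss, grad, zeros(2); maxiters=1000, cb=stop_when_small)
```

The callback above always returns false, so training runs for the full `maxiters`; returning a condition instead gives an early-stopping criterion.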
Having a look at the loss with the initial parameters:
lossC(p)
First, training the network for 100 iterations with the ADAM optimizer (learning rate 0.1):
res1 = DiffEqFlux.sciml_train(lossC, p, Flux.ADAM(.1), maxiters=100, cb = callback)
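For readers unfamiliar with ADAM, the update rule fits in a few lines. This is a hypothetical sketch on a toy quadratic, not Flux's code; the defaults β1 = 0.9, β2 = 0.999, ϵ = 1e-8 are the standard ones from the ADAM paper and, as far as we know, also Flux's defaults for `ADAM(η)`.

```julia
# Minimal ADAM on a toy problem; `adam` is a hypothetical helper.
function adam(grad, p0; η=0.1, β1=0.9, β2=0.999, ϵ=1e-8, iters=500)
    p = copy(p0)
    mom = zero(p)   # running mean of gradients (first moment)
    vel = zero(p)   # running mean of squared gradients (second moment)
    for t in 1:iters
        gr = grad(p)
        @. mom = β1 * mom + (1 - β1) * gr
        @. vel = β2 * vel + (1 - β2) * gr^2
        mhat = mom ./ (1 - β1^t)   # bias correction for the zero initialization
        vhat = vel ./ (1 - β2^t)
        @. p -= η * mhat / (sqrt(vhat) + ϵ)
    end
    return p
end

# Badly scaled quadratic (p1 - 1)^2 + 10 (p2 + 2)^2 with minimum at [1, -2].
toygrad(p) = [2 * (p[1] - 1), 20 * (p[2] + 2)]
p_adam = adam(toygrad, zeros(2))
```

Because each coordinate's step is divided by its own gradient magnitude, the first steps have size close to η regardless of scaling, which makes ADAM with a fairly large η a robust first phase before switching to a quasi-Newton method.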
How did the loss improve?
lossC(res1.u)
Switching to the quasi-Newton method BFGS from Optim.jl, starting from the ADAM result:
res2 = DiffEqFlux.sciml_train(lossC, res1.minimizer, BFGS(initial_stepnorm = 0.01), maxiters = 150, allow_f_increases = true, cb = callback)
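The core of BFGS is a rank-two update of an approximate inverse Hessian built from successive steps and gradient differences. The following is a simplified sketch with plain Armijo backtracking on a toy quadratic; `bfgs` is a hypothetical helper, and Optim's BFGS uses a more elaborate line search and supports options such as `initial_stepnorm` that are omitted here.

```julia
using LinearAlgebra

function bfgs(f, grad, x0; iters=100, tol=1e-10)
    x = copy(x0)
    H = Matrix{Float64}(I, length(x), length(x))  # inverse-Hessian approximation, starts as identity
    gx = grad(x)
    for _ in 1:iters
        norm(gx) < tol && break
        d = -H * gx                                # quasi-Newton search direction
        α = 1.0
        # Armijo backtracking: halve the step until f decreases sufficiently
        while α > 1e-12 && f(x + α * d) > f(x) + 1e-4 * α * dot(gx, d)
            α /= 2
        end
        s = α * d
        xnew = x + s
        gnew = grad(xnew)
        y = gnew - gx
        sy = dot(s, y)
        if sy > 1e-10 * norm(s) * norm(y)          # skip the update without positive curvature
            ρ = 1 / sy
            H = (I - ρ * s * y') * H * (I - ρ * y * s') + ρ * s * s'
        end
        x, gx = xnew, gnew
    end
    return x
end

f(x) = (x[1] - 1)^2 + 10 * (x[2] + 2)^2
g(x) = [2 * (x[1] - 1), 20 * (x[2] + 2)]
xmin = bfgs(f, g, zeros(2))
```

Near a minimum, H approaches the inverse Hessian and the method converges superlinearly, which is why it is useful for polishing the parameters found by ADAM.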
The loss can be modified to additionally penalize the final angular velocity and cart velocity, and then minimized with further optimization steps. The modified loss:
function lossC(theta)
    pr = predictC(theta)
    # final theta, s, dtheta, ds
    loss = 10*abs2(π - pr[1,end]) + 0.6*abs2(pr[2,end]) + 1*abs2(pr[3,end]) + 1*abs2(pr[4,end])
    return loss
end
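Both loss variants are instances of one weighted quadratic cost on the final state x_T = (θ, s, θ̇, ṡ) with target (π, 0, 0, 0). The sketch below, with the hypothetical helper `terminal_cost`, makes the weights explicit so further variants can be tried by changing a single vector.

```julia
# Weighted quadratic cost on the final state; `terminal_cost` is hypothetical.
terminal_cost(xT, w; target=[π, 0.0, 0.0, 0.0]) = sum(w .* abs2.(xT .- target))

w_first  = [10.0, 0.6, 0.0, 0.0]   # first lossC: angle and cart position only
w_second = [10.0, 0.6, 1.0, 1.0]   # second lossC: also the final velocities

xT = [3.0, 0.2, 0.5, -0.1]        # an arbitrary example final state
c1 = terminal_cost(xT, w_first)
c2 = terminal_cost(xT, w_second)  # larger by 0.5^2 + 0.1^2 = 0.26
```

Since only the final state enters the cost, the trajectory in between is unconstrained; penalizing the velocities asks the controller to arrive at the target at rest rather than merely pass through it.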
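res3 = DiffEqFlux.sciml_train(lossC, res2.minimizer, BFGS(initial_stepnorm = 0.01), maxiters = 150, allow_f_increases = true, cb = callback)
# further BFGS steps, continuing from res3
res4 = DiffEqFlux.sciml_train(lossC, res3.minimizer, BFGS(initial_stepnorm = 0.01), maxiters = 150, allow_f_increases = true, cb = callback)
lossC(res4.u)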
Plotting the state trajectories for the trained parameters; the horizontal lines mark the targets π and 0:
plot(t_span[1]:Δt:t_span[2], predictC(res4.u)')
hline!([pi,0])
Now we can verify the behavior of the system with the package ReachabilityAnalysis.jl.
For this, we redefine the cart-pole differential equation with the already trained weights (res4.u) fixed, so that it depends on the state only:
function cartpoleCV(du,u,p,t)  # theta, s, dtheta, ds = u[1:4]
    # same dynamics as cartpoleC, with the trained weights res4.u fixed;
    # the parameter argument p is ignored
    cartpoleC(du, u, res4.u, t)
    return nothing
end
Then we define an initial set rather than a single initial vector: a box with half-width 0.01 around the origin in all four state variables:
U0 = Hyperrectangle(zeros(4), ones(4) .* 0.01)
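A hyperrectangle is described by a center and a radius vector. The package-free sketch below (the hypothetical `Box` type, a much-reduced stand-in for LazySets' Hyperrectangle) shows what this set contains: every state within 0.01 of the origin in each coordinate, i.e. infinitely many initial conditions that the reachability analysis covers at once, in contrast to the single x0 used for training.

```julia
# Hypothetical minimal box type: the set {x : |x_i - c_i| <= r_i for all i}.
struct Box
    center::Vector{Float64}
    radius::Vector{Float64}
end

Base.in(x::AbstractVector, B::Box) = all(abs.(x .- B.center) .<= B.radius)

# All 2^n corner points of the box (bit i of k selects the sign in coordinate i).
function corners(B::Box)
    n = length(B.center)
    return [B.center .+ B.radius .* [isodd(k >> (i - 1)) ? 1.0 : -1.0 for i in 1:n]
            for k in 0:2^n-1]
end

U0 = Box(zeros(4), fill(0.01, 4))
[0.005, -0.01, 0.0, 0.01] ∈ U0   # true: every coordinate within 0.01
[0.02, 0.0, 0.0, 0.0] ∈ U0       # false: angle offset too large
```

In four dimensions the box already has 16 corners, which hints at why simulating a few sample points is no substitute for propagating the whole set.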
Defining the initial value problem for the reachability analysis:
ivp_cartpole = @ivp(u' = cartpoleCV(u), dim=4, u(0) ∈ U0)
Solving the problem over the time span [0, 2] s (twice the training horizon) and then overapproximating the flowpipe with zonotopes:
sol_cartpole = ReachabilityAnalysis.solve(ivp_cartpole, (0.0, 2.0));
zsol_cartpole = overapproximate(sol_cartpole, Zonotope);
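A zonotope is a center plus a set of generator vectors, Z = {c + Gξ : ξ ∈ [-1, 1]^p}. One convenient property is that linear maps are exact: the image of a zonotope is again a zonotope. The package-free sketch below (hypothetical `Zono` type, not LazySets' Zonotope) rotates a 2D box by 45°: the zonotope represents the rotated square exactly, while its axis-aligned interval hull, the kind of box that would be plotted per coordinate, has twice the area.

```julia
using LinearAlgebra

# Hypothetical minimal zonotope type.
struct Zono
    c::Vector{Float64}   # center
    G::Matrix{Float64}   # generators as columns
end

# A box (center, radius) is a zonotope with diagonal generators.
box_to_zono(center, radius) = Zono(center, Matrix(Diagonal(radius)))
linear_map(A, Z::Zono) = Zono(A * Z.c, A * Z.G)      # exact for zonotopes
box_radius(Z::Zono) = vec(sum(abs.(Z.G); dims=2))    # half-widths of the interval hull

Z0 = box_to_zono([0.0, 0.0], [1.0, 1.0])
A = [cos(pi/4) -sin(pi/4); sin(pi/4) cos(pi/4)]     # rotation by 45 degrees
Z1 = linear_map(A, Z0)
r = box_radius(Z1)                                   # ≈ [√2, √2]

zono_area = 4 * abs(det(Z1.G))   # area of the rotated square (two generators)
hull_area = prod(2 .* r)         # area of its interval hull
```

When each state variable is plotted against time, the reach sets are projected onto one coordinate, so the plotted bands are such interval bounds of the underlying sets.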
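Plotting each state variable over time (vars=(0,i) plots state variable i against time, index 0 denoting time) gives the result shown above: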
p2 = plot(zsol_cartpole, vars=(0,1), label="Reach Θ")
plot!(zsol_cartpole, vars=(0,2), label = "Reach S")
plot!(zsol_cartpole, vars=(0,3), label = "Reach dΘ")
plot!(zsol_cartpole, vars=(0,4), label = "Reach dS")
hline!([pi,0], label = "Target", xlabel="t in s", ylabel = "Angle in rad, Distance in m" )