Merged
13 changes: 13 additions & 0 deletions README.md
@@ -1 +1,14 @@
# DataDrivenControl

## Algorithms
### Integral Reinforcement Learning (IRL)
- `LinearIRL`: IRL algorithm for affine dynamics with a cost quadratic in the control input (see the usage sketch below)
- Update methods
1. `policy_iteration!` [1, Eqs. (98), (96)], [2]
2. `value_iteration!` [1, Eqs. (99), (96)]
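
A minimal usage sketch, adapted from `main/linear_irl.jl` in this PR (the closed-loop simulation that fills the buffer is elided; see the example script for the full workflow):

```julia
using DataDrivenControl
using LinearAlgebra

n, m = 2, 1
Q = Matrix(I, n, n)
R = Matrix(I, m, m)
B = [0 2]'
cost = DataDrivenControl.QuadraticCost(Q, R)
controller = LinearIRL(Q, R, B; T=0.04)  # T: integration (sampling) period
buffer = DataDrivenControl.DataBuffer()
w = zeros(Int(n*(n + 1)/2))  # critic parameter for x'Px = w'ϕ(x)
# at every period T, log a sample, e.g.
#   push!(buffer, controller, cost; t=t, x=x, u=u, w=w)
# and then update the critic in place:
policy_iteration!(controller, buffer, w; sc=DistanceStopCondition(1e-1))
```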


## References
[1] F. L. Lewis, D. Vrabie, and K. G. Vamvoudakis, “Reinforcement Learning and Feedback Control: Using Natural Decision Methods to Design Optimal Adaptive Controllers,” IEEE Control Syst., vol. 32, no. 6, pp. 76–105, Dec. 2012, doi: 10.1109/MCS.2012.2214134.

[2] D. Vrabie, O. Pastravanu, M. Abu-Khalaf, and F. L. Lewis, “Adaptive Optimal Control for Continuous-Time Linear Systems Based on Policy Iteration,” Automatica, vol. 45, no. 2, pp. 477–484, Feb. 2009, doi: 10.1016/j.automatica.2008.08.017.
91 changes: 91 additions & 0 deletions main/linear_irl.jl
@@ -0,0 +1,91 @@
using DataDrivenControl
using FlightSims
using UnPack
using Plots
using LinearAlgebra
using Transducers
using LaTeXStrings
using DiffEqCallbacks


struct LinearSystem_ZOH_Gain
    linear::LinearSystem
    controller::LinearIRL
end

function FlightSims.State(env::LinearSystem_ZOH_Gain)
    @unpack linear = env
    State(linear)
end

function FlightSims.Dynamics!(env::LinearSystem_ZOH_Gain)
    @unpack linear, controller = env
    @Loggable function dynamics!(dx, x, w, t)
        u = optimal_input(controller, x, w)
        @onlylog param = w
        @onlylog i = controller.i
        @nested_log Dynamics!(linear)(dx, x, w, t; u=u)
    end
end

# See [1, Example 2. Continuous-Time Optimal Adaptive Control Using IRL]
# Refs
# [1] “Reinforcement Learning and Feedback Control: Using Natural Decision Methods to Design Optimal Adaptive Controllers,” IEEE Control Syst., vol. 32, no. 6, pp. 76–105, Dec. 2012, doi: 10.1109/MCS.2012.2214134.
function main()
    n, m = 2, 1
    A = [-10 1;
         -0.002 -2]
    B = [0 2]'
    Q = Matrix(I, n, n)
    R = Matrix(I, m, m)
    cost = DataDrivenControl.QuadraticCost(Q, R)
    linear = LinearSystem(A, B)
    controller = DataDrivenControl.LinearIRL(Q, R, B; T=0.04)
    env = LinearSystem_ZOH_Gain(linear, controller)
    x0 = State(env)(rand(2))
    # TODO: add callback to update param `w`
    # simulation
    P_true = [0.0500 0.0039;
              0.0039 0.2085]  # true solution
    w_true = [P_true[1, 1], P_true[2, 1]+P_true[1, 2], P_true[2, 2]]
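    # x'Px = w'ϕ(x) with ϕ(x) = [x₁², x₁x₂, x₂²], hence w = [P₁₁, P₁₂+P₂₁, P₂₂]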
    scale = 0.1
    w0 = w_true + scale*randn(3)  # perturbed initial guess
    tf = 5.0
    simulator = Simulator(x0, Dynamics!(env), w0;
                          tf=tf,
                         )
    Δt = 0.01
    # TODO: the data buffer duplicates the saving callback (see `df`)
    # data buffer
    buffer = DataDrivenControl.DataBuffer()
    function update!(integrator)
        t = integrator.t
        x = integrator.u  # convention of DifferentialEquations.jl
        w = integrator.p  # convention of DifferentialEquations.jl
        u = optimal_input(controller, x, w)  # TODO: avoid recomputing the input applied in the dynamics
        push!(buffer, controller, cost;
              t=t, x=copy(x), u=copy(u), w=copy(w))
        eps = 1e-1
        sc = DistanceStopCondition(eps)
        # value_iteration!(controller, buffer, w; sc=sc)
        policy_iteration!(controller, buffer, w; sc=sc)
    end
    cb_irl = PeriodicCallback(update!, controller.T; initial_affect=true)  # stack initial data
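    # `update!` fires every `T` seconds; each call appends one (ϕ, ∫r) sample to
    # `buffer` and attempts a critic update once enough samples are collected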
    df = solve(simulator;
               callback=cb_irl,
               savestep=Δt,
              )
    # plot
    ts = df.time
    xs = df.sol |> Map(datum -> datum.state) |> collect
    us = df.sol |> Map(datum -> datum.input) |> collect
    ws = df.sol |> Map(datum -> datum.param) |> collect
    is = df.sol |> Map(datum -> datum.i) |> collect
    fig_x = plot(ts, hcat(xs...)'; label=[L"x_{1}" L"x_{2}"], legend=:outerbottomright)
    fig_u = plot(ts, hcat(us...)'; label=L"u", legend=:outerbottomright)
    fig_w = plot(ts, hcat(ws...)'; label=[L"w_{1}" L"w_{2}" L"w_{3}"], legend=:outerbottomright)
    fig_iter = plot(ts, hcat(is...)'; label="iter")
    ws_true = ts |> Map(t -> w_true) |> collect
    plot!(fig_w, ts, hcat(ws_true...)'; color=:black, label=nothing)
    fig = plot(fig_x, fig_u, fig_w, fig_iter)
end
7 changes: 7 additions & 0 deletions src/DataDrivenControl.jl
@@ -6,6 +6,13 @@ using Transducers

## Costs
export QuadraticInInputCost, QuadraticCost
## IRL
# Linear IRL
export LinearIRL, value_iteration!, policy_iteration!
export optimal_input

## Stop conditions
export DistanceStopCondition


include("utils/utils.jl")
86 changes: 73 additions & 13 deletions src/irl/linear_irl.jl
@@ -1,5 +1,5 @@
"""
See [1, "Online Implementation of IRL: A Hybrid Optimal Adaptive Controller"].

# Refs
[1] “Reinforcement Learning and Feedback Control: Using Natural Decision Methods to Design Optimal Adaptive Controllers,” IEEE Control Syst., vol. 32, no. 6, pp. 76–105, Dec. 2012, doi: 10.1109/MCS.2012.2214134.
@@ -16,6 +16,7 @@ mutable struct LinearIRL <: AbstractIRL
    N::Int
    i::Int
    i_init::Int
    terminated::Bool
    function LinearIRL(Q::AbstractMatrix, R::AbstractMatrix, B::AbstractMatrix;
                       T=0.04,
                       N=nothing,
@@ -30,9 +31,19 @@
        n1, n2 = size(Q)
        @assert n1 == n2
        N_min = Int(n1*(n1 + 1)/2)
        if N == nothing
            N = N_min
        else
            @assert N >= N_min
        end
        @assert T > 0 && N > 0
        R_inv = inv(R)
        i = i_init
        terminated = false
        new(Q, R_inv, B, T, N, i, i_init, terminated)
    end
end

@@ -47,22 +58,69 @@ w: critic parameter (vectorised)
- ϕs_prev: the vector of bases (evaluated)
- V̂: the vector of approximate values (evaluated)
"""
function value_iteration!(irl::LinearIRL, buffer::DataBuffer, w;
                          sc::AbstractStopCondition=DistanceStopCondition(),
                         )
    @unpack i, N = irl
    @unpack data_array = buffer
    data_filtered = filter(x -> x.i == i, data_array)  # data from the current policy
    if length(data_filtered) >= N + 1
        data_sorted = sort(data_filtered, by=x -> x.t)  # sort by time index
        ϕs_prev = data_sorted[end-N:end-1] |> Map(datum -> datum.ϕ) |> collect
        # V̂s = data_sorted[end-(N-1):end] |> Map(datum -> datum.V̂) |> collect
        xs = data_sorted[end-(N-1):end] |> Map(datum -> datum.x) |> collect
        ∫rs = data_sorted[end-(N-1):end] |> Map(datum -> datum.∫r) |> collect
        P = convert_to_matrix(w)
        V̂s_with_prev_P = xs |> Map(x -> value(irl, P, x)) |> collect
        V̂s = ∫rs .+ V̂s_with_prev_P
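        # value-iteration target [1, Eq. (99)]: V̂_{i+1}(x(t)) = ∫_t^{t+T} r dτ + V̂_i(x(t+T));
        # the new critic parameter below solves w'ϕ(x(t)) ≈ V̂ in a least-squares sense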
        # update the critic parameter
        # w .= pinv(hcat(ϕs_prev...)') * hcat(V̂s...)'  # least-squares sense
        w_new = ( hcat(V̂s...) * pinv(hcat(ϕs_prev...)) )'[:]  # to reduce the computation time; [:] for vectorisation
        if !irl.terminated
            if is_terminated(sc, w, w_new)
                irl.terminated = true
            else
                w .= w_new
                update_index!(irl)
            end
        end
    end
end

"""
Policy iteration [1, Eq. 98]; updated in least-square sense
# Notes
w: critic parameter (vectorised)
- Δϕs: the vector of (ϕ - ϕ_prev) (evaluated)
- ∫rs: the vector of integral running cost by numerical integration (evaluated)
"""
function policy_iteration!(irl::LinearIRL, buffer::DataBuffer, w;
                           sc::AbstractStopCondition=DistanceStopCondition(),
                          )
    @unpack i, N = irl
    @unpack data_array = buffer
    data_filtered = filter(x -> x.i == i, data_array)  # data from the current policy
    if length(data_filtered) >= N + 1
        data_sorted = sort(data_filtered, by=x -> x.t)  # sort by time index
        ϕs_prev_and_present = data_sorted[end-N:end] |> Map(datum -> datum.ϕ) |> collect
        Δϕs = diff(ϕs_prev_and_present)
        ∫rs = data_sorted[end-(N-1):end] |> Map(datum -> datum.∫r) |> collect
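        # policy-evaluation equation [1, Eq. (98)]: w'(ϕ(x(t)) - ϕ(x(t+T))) = ∫_t^{t+T} r dτ,
        # i.e. w'(-Δϕ) = ∫r over each interval, solved below in a least-squares sense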
        # update the critic parameter
        w_new = ( hcat(∫rs...) * pinv(hcat(-Δϕs...)) )'[:]  # to reduce the computation time; [:] for vectorisation
        if !irl.terminated
            if is_terminated(sc, w, w_new)
                irl.terminated = true
            else
                w .= w_new
                update_index!(irl)
            end
        end
    end
end

function update_index!(irl::LinearIRL)
    irl.i += 1
end

"""
Policy improvement [1, Eq. (96)].
"""
@@ -96,11 +154,11 @@ w: critic parameter (vectorised)
function Base.push!(buffer::DataBuffer, irl::LinearIRL, cost::AbstractCost;
                    t, x, u, w,
                   )
    # P = convert_to_matrix(w)
    @unpack data_array = buffer
    @unpack i = irl
    data_sorted = sort(data_array, by = x -> x.t)  # sorted by t
    # V̂_with_prev_P = value(irl, P, x)
    # prev data
    if length(data_sorted) != 0
        t_prev = data_sorted[end].t
@@ -111,9 +169,10 @@
        r = cost(x, u)
        r_prev = cost(x_prev, u_prev)
        ∫r = 0.5 * (r + r_prev) * Δt  # trapezoidal rule
        # V̂ = ∫r + V̂_with_prev_P
    else
        # V̂ = missing
        ∫r = missing
    end
    ϕ = convert_quadratic_to_linear_basis(x)  # x'Px = w'ϕ(x)
    datum = (;
@@ -122,7 +181,8 @@
        u=u,
        w=w,  # logging
        ϕ=ϕ,
        ∫r=∫r,
        # V̂=V̂,
        i=i,  # iteration number
    )
    push!(buffer.data_array, datum)
20 changes: 20 additions & 0 deletions src/utils/stop_conditions/stop_conditions.jl
@@ -0,0 +1,20 @@
abstract type AbstractStopCondition end

function is_terminated(sc::AbstractStopCondition, args...; kwargs...)
    error("Define this method for type: $(typeof(sc))")
end


struct DistanceStopCondition <: AbstractStopCondition
    eps::Real
    p::Real
    function DistanceStopCondition(eps=1e-1, p=2)
        @assert p >= 1  # p-norms are defined for p ≥ 1
        new(eps, p)  # match the field order (eps, p)
    end
end

function is_terminated(sc::DistanceStopCondition, w, w_new)
    @unpack eps, p = sc
    norm(w_new - w, p) < eps
end
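
# Usage sketch (hypothetical values): terminate once the critic parameter stops
# moving, e.g. is_terminated(DistanceStopCondition(1e-1, 2), w, w_new)
# returns true iff ‖w_new - w‖₂ < 1e-1.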
1 change: 1 addition & 0 deletions src/utils/utils.jl
@@ -1,3 +1,4 @@
include("convert.jl")
include("data_buffer.jl")
include("cost/cost.jl")
include("stop_conditions/stop_conditions.jl")
3 changes: 2 additions & 1 deletion test/irl/irl.jl
@@ -30,7 +30,8 @@ using LinearAlgebra
@test length(buffer.data_array) == length(ts)
ŵ_new = deepcopy(ŵ)
i = deepcopy(irl.i)
DataDrivenControl.policy_iteration!(irl, buffer, ŵ_new)
DataDrivenControl.value_iteration!(irl, buffer, ŵ_new)
@test ŵ_new != ŵ # updated?
@test i + 1 == irl.i
DataDrivenControl.reset!(irl)