diff --git a/README.md b/README.md
index 6152150..b34f4ab 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,14 @@
 # DataDrivenControl
+
+## Algorithms
+### Integral Reinforcement Learning (IRL)
+- `LinearIRL`: IRL algorithm for input-affine dynamics and a cost quadratic in the control input
+    - Update methods
+        1. `policy_iteration!` [1, Eqs. (98), (96)], [2]
+        2. `value_iteration!` [1, Eqs. (99), (96)]
+
+
+## References
+[1] F. L. Lewis, D. Vrabie, and K. G. Vamvoudakis, “Reinforcement Learning and Feedback Control: Using Natural Decision Methods to Design Optimal Adaptive Controllers,” IEEE Control Syst., vol. 32, no. 6, pp. 76–105, Dec. 2012, doi: 10.1109/MCS.2012.2214134.
+
+[2] D. Vrabie, O. Pastravanu, M. Abu-Khalaf, and F. L. Lewis, “Adaptive Optimal Control for Continuous-Time Linear Systems Based on Policy Iteration,” Automatica, vol. 45, no. 2, pp. 477–484, Feb. 2009, doi: 10.1016/j.automatica.2008.08.017.
diff --git a/main/linear_irl.jl b/main/linear_irl.jl
new file mode 100644
index 0000000..ea1208c
--- /dev/null
+++ b/main/linear_irl.jl
@@ -0,0 +1,91 @@
+using DataDrivenControl
+using FlightSims
+using UnPack
+using Plots
+using LinearAlgebra
+using Transducers
+using LaTeXStrings
+using DiffEqCallbacks
+
+
+# linear plant driven by a zero-order-hold (ZOH) IRL gain
+struct LinearSystem_ZOH_Gain
+    linear::LinearSystem
+    controller::LinearIRL
+end
+
+function FlightSims.State(env::LinearSystem_ZOH_Gain)
+    @unpack linear = env
+    State(linear)
+end
+
+function FlightSims.Dynamics!(env::LinearSystem_ZOH_Gain)
+    @unpack linear, controller = env
+    @Loggable function dynamics!(dx, x, w, t)
+        u = optimal_input(controller, x, w)
+        @onlylog param = w
+        @onlylog i = controller.i
+        @nested_log Dynamics!(linear)(dx, x, w, t; u=u)
+    end
+end
+
+# See [1, Example 2. Continuous-Time Optimal Adaptive Control Using IRL]
+# Refs
+# [1] “Reinforcement Learning and Feedback Control: Using Natural Decision Methods to Design Optimal Adaptive Controllers,” IEEE Control Syst., vol. 32, no. 6, pp. 76–105, Dec. 2012, doi: 10.1109/MCS.2012.2214134.
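+# How it works: the critic is parameterised as V(x) = x'Px = w'ϕ(x). Over each
+# interval of length T, the IRL Bellman equation
+#     w'ϕ(x(t)) = ∫_{t}^{t+T} (x'Qx + u'Ru) dτ + w'ϕ(x(t+T))
+# is solved for w in the least-squares sense by `policy_iteration!` (or, with the
+# previous critic kept on the right-hand side, by `value_iteration!`), and
+# `optimal_input` applies the improved gain [1, Eq. (96)].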
+
+function main()
+    n, m = 2, 1
+    A = [-10 1;
+         -0.002 -2]
+    B = [0 2]'
+    Q = Matrix(I, n, n)
+    R = Matrix(I, m, m)
+    cost = DataDrivenControl.QuadraticCost(Q, R)
+    linear = LinearSystem(A, B)
+    controller = DataDrivenControl.LinearIRL(Q, R, B; T=0.04)
+    env = LinearSystem_ZOH_Gain(linear, controller)
+    x0 = State(env)(rand(2))
+    # simulation
+    P_true = [0.0500 0.0039;
+              0.0039 0.2085]  # true solution
+    w_true = [P_true[1, 1], P_true[2, 1]+P_true[1, 2], P_true[2, 2]]
+    scale = 0.1
+    w0 = w_true + scale*randn(3)  # perturbed initial guess
+    tf = 5.0
+    simulator = Simulator(x0, Dynamics!(env), w0;
+                          tf=tf,
+                         )
+    Δt = 0.01
+    # TODO: the data buffer duplicates data already saved by the saving callback (see `df`)
+    # data buffer
+    buffer = DataDrivenControl.DataBuffer()
+    tol = 1e-1  # stop-condition tolerance (named `tol` so as not to shadow `Base.eps`)
+    sc = DistanceStopCondition(tol)
+    function update!(integrator)
+        t = integrator.t
+        x = integrator.u  # convention of DifferentialEquations.jl
+        w = integrator.p  # convention of DifferentialEquations.jl
+        u = optimal_input(controller, x, w)  # TODO: avoid recomputing the input already applied in the dynamics
+        push!(buffer, controller, cost;
+              t=t, x=copy(x), u=copy(u), w=copy(w))
+        # choose one update method:
+        # value_iteration!(controller, buffer, w; sc=sc)
+        policy_iteration!(controller, buffer, w; sc=sc)
+    end
+    cb_irl = PeriodicCallback(update!, controller.T; initial_affect=true)  # also fire at t = 0 to store the initial datum
+    df = solve(simulator;
+               callback=cb_irl,
+               savestep=Δt,
+              )
+    # plot
+    ts = df.time
+    xs = df.sol |> Map(datum -> datum.state) |> collect
+    us = df.sol |> Map(datum -> datum.input) |> collect
+    ws = df.sol |> Map(datum -> datum.param) |> collect
+    is = df.sol |> Map(datum -> datum.i) |> collect
+    fig_x = plot(ts, hcat(xs...)'; label=[L"x_{1}" L"x_{2}"], legend=:outerbottomright)
+    fig_u = plot(ts, hcat(us...)'; label=L"u", legend=:outerbottomright)
+    fig_w = plot(ts, hcat(ws...)'; label=[L"w_{1}" L"w_{2}" L"w_{3}"], legend=:outerbottomright)
+    fig_iter = plot(ts, hcat(is...)'; label="iter")
+    ws_true = ts |> Map(t -> w_true) |> collect
+    plot!(fig_w, ts, hcat(ws_true...)'; color=:black, label=nothing)
+    fig = plot(fig_x, fig_u, fig_w, fig_iter)
+end
diff --git a/src/DataDrivenControl.jl b/src/DataDrivenControl.jl
index f0c35d7..3e031dc 100644
--- a/src/DataDrivenControl.jl
+++ b/src/DataDrivenControl.jl
@@ -6,6 +6,13 @@ using Transducers
 ## Costs
 export QuadraticInInputCost, QuadraticCost
 
+## IRL
+# Linear IRL
+export LinearIRL, value_iteration!, policy_iteration!
+export optimal_input
+
+## Stop conditions
+export DistanceStopCondition
 
 include("utils/utils.jl")
 
diff --git a/src/irl/linear_irl.jl b/src/irl/linear_irl.jl
index f1394cf..3dde20c 100644
--- a/src/irl/linear_irl.jl
+++ b/src/irl/linear_irl.jl
@@ -1,5 +1,5 @@
 """
-See [1, "IRL Optimal Adaptive Control Using Value Iteration"].
+See [1, "Online Implementation of IRL: A Hybrid Optimal Adaptive Controller"].
 # Refs
 [1] “Reinforcement Learning and Feedback Control: Using Natural Decision Methods to Design Optimal Adaptive Controllers,” IEEE Control Syst., vol. 32, no. 6, pp. 76–105, Dec. 2012, doi: 10.1109/MCS.2012.2214134.
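+# Notes
+The controller is hybrid [1]: the plant state evolves continuously, while the
+critic parameter w is updated every T seconds from sampled data; the plant drift
+matrix A is not needed (only B enters, through the policy-improvement step).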
@@ -16,6 +16,7 @@ mutable struct LinearIRL <: AbstractIRL
     N::Int
     i::Int
     i_init::Int
+    terminated::Bool
     function LinearIRL(Q::AbstractMatrix, R::AbstractMatrix, B::AbstractMatrix;
             T=0.04, N=nothing,
@@ -30,9 +31,19 @@ mutable struct LinearIRL <: AbstractIRL
             @assert N >= N_min
         end
         @assert T > 0 && N > 0
         R_inv = inv(R)
         i = i_init
-        new(Q, R_inv, B, T, N, i, i_init)
+        terminated = false
+        new(Q, R_inv, B, T, N, i, i_init, terminated)
     end
 end
@@ -47,22 +58,69 @@ w: critic parameter (vectorised)
 - ϕs_prev: the vector of bases (evaluated)
 - V̂: the vector of approximate values (evaluated)
 """
-function evaluate_policy!(irl::LinearIRL, buffer::DataBuffer, w,
+function value_iteration!(irl::LinearIRL, buffer::DataBuffer, w;
+        sc::AbstractStopCondition=DistanceStopCondition(),
     )
     @unpack i, N = irl
     @unpack data_array = buffer
     data_filtered = filter(x -> x.i == i, data_array)  # data from the current policy
-    if length(data_filtered) >= N
+    if length(data_filtered) >= N + 1
         data_sorted = sort(data_filtered, by=x -> x.t)  # sort by time
         ϕs_prev = data_sorted[end-N:end-1] |> Map(datum -> datum.ϕ) |> collect
-        V̂s = data_sorted[end-(N-1):end] |> Map(datum -> datum.V̂) |> collect
-        irl.i += 1  # update iteration number
+        xs = data_sorted[end-(N-1):end] |> Map(datum -> datum.x) |> collect
+        ∫rs = data_sorted[end-(N-1):end] |> Map(datum -> datum.∫r) |> collect
+        P = convert_to_matrix(w)
+        V̂s_with_prev_P = xs |> Map(x -> value(irl, P, x)) |> collect
+        V̂s = ∫rs .+ V̂s_with_prev_P  # target values built with the previous critic
         # update the critic parameter
-        w .= ( hcat(V̂s...) * pinv(hcat(ϕs_prev...)) )'[:]  # to reduce the computation time; [:] for vectorisation
-        # w .= pinv(hcat(ϕs_prev...)') * hcat(V̂s...)'  # least square sense
+        w_new = ( hcat(V̂s...) * pinv(hcat(ϕs_prev...)) )'[:]  # to reduce the computation time; [:] for vectorisation
+        if !irl.terminated
+            if is_terminated(sc, w, w_new)
+                irl.terminated = true
+            else
+                w .= w_new
+                update_index!(irl)
+            end
+        end
    end
 end
 
+"""
+Policy iteration [1, Eq. 98]; the critic parameter is updated in the least-squares sense.
+# Notes
+w: critic parameter (vectorised)
+- Δϕs: the vector of (ϕ - ϕ_prev) (evaluated)
+- ∫rs: the vector of integral running costs by numerical integration (evaluated)
+"""
+function policy_iteration!(irl::LinearIRL, buffer::DataBuffer, w;
+        sc::AbstractStopCondition=DistanceStopCondition(),
+    )
+    @unpack i, N = irl
+    @unpack data_array = buffer
+    data_filtered = filter(x -> x.i == i, data_array)  # data from the current policy
+    if length(data_filtered) >= N + 1
+        data_sorted = sort(data_filtered, by=x -> x.t)  # sort by time
+        ϕs_prev_and_present = data_sorted[end-N:end] |> Map(datum -> datum.ϕ) |> collect
+        Δϕs = diff(ϕs_prev_and_present)
+        ∫rs = data_sorted[end-(N-1):end] |> Map(datum -> datum.∫r) |> collect
+        # update the critic parameter
+        w_new = ( hcat(∫rs...) * pinv(hcat(-Δϕs...)) )'[:]  # to reduce the computation time; [:] for vectorisation
+        if !irl.terminated
+            if is_terminated(sc, w, w_new)
+                irl.terminated = true
+            else
+                w .= w_new
+                update_index!(irl)
+            end
+        end
+    end
+end
+
+function update_index!(irl::LinearIRL)
+    irl.i += 1
+end
+
 """
 Policy improvement [1, Eq. 96].
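+
+For a linear system with quadratic cost, this is the state-feedback input
+u = -R⁻¹B'Px, where P is recovered from the critic parameter w
+(see `convert_to_matrix`).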
""" @@ -96,11 +154,11 @@ w: critic parameter (vectorised) function Base.push!(buffer::DataBuffer, irl::LinearIRL, cost::AbstractCost; t, x, u, w, ) - P = convert_to_matrix(w) + # P = convert_to_matrix(w) @unpack data_array = buffer @unpack i = irl data_sorted = sort(data_array, by = x -> x.t) # sorted by t - V̂_with_prev_P = value(irl, P, x) + # V̂_with_prev_P = value(irl, P, x) # prev data if length(data_sorted) != 0 t_prev = data_sorted[end].t @@ -111,9 +169,10 @@ function Base.push!(buffer::DataBuffer, irl::LinearIRL, cost::AbstractCost; r = cost(x, u) r_prev = cost(x_prev, u_prev) ∫r = 0.5 * (r + r_prev) * Δt # trapezoidal - V̂ = ∫r + V̂_with_prev_P + # V̂ = ∫r + V̂_with_prev_P else - V̂ = missing + # V̂ = missing + ∫r = missing end ϕ = convert_quadratic_to_linear_basis(x) # x'Px = w'ϕ(x) datum = (; @@ -122,7 +181,8 @@ function Base.push!(buffer::DataBuffer, irl::LinearIRL, cost::AbstractCost; u=u, w=w, # logging ϕ=ϕ, - V̂=V̂, + ∫r=∫r, + # V̂=V̂, i=i, # iteration number ) push!(buffer.data_array, datum) diff --git a/src/utils/stop_conditions/stop_conditions.jl b/src/utils/stop_conditions/stop_conditions.jl new file mode 100644 index 0000000..a3077f9 --- /dev/null +++ b/src/utils/stop_conditions/stop_conditions.jl @@ -0,0 +1,20 @@ +abstract type AbstractStopCondition end + +function is_terminated(sc::AbstractStopCondition, args...; kwrags...) + error("Defined this method for type: $(typeof(sc))") +end + + +struct DistanceStopCondition <: AbstractStopCondition + eps::Real + p::Real + function DistanceStopCondition(eps=1e-1, p=2) + @assert p > 1 + new(p, eps) + end +end + +function is_terminated(sc::DistanceStopCondition, w, w_new) + @unpack eps, p = sc + norm(w_new - w, p) < eps +end diff --git a/src/utils/utils.jl b/src/utils/utils.jl index 07f09d2..cd1eea6 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -1,3 +1,4 @@ include("convert.jl") include("data_buffer.jl") include("cost/cost.jl") +include("stop_conditions/stop_conditions.jl") diff --git a/test/irl/irl.jl b/test/irl/irl.jl index b896bae..59a8a66 100644 --- a/test/irl/irl.jl +++ b/test/irl/irl.jl @@ -30,7 +30,8 @@ using LinearAlgebra @test length(buffer.data_array) == length(ts) ŵ_new = deepcopy(ŵ) i = deepcopy(irl.i) - DataDrivenControl.evaluate_policy!(irl, buffer, ŵ_new) + DataDrivenControl.policy_iteration!(irl, buffer, ŵ_new) + DataDrivenControl.value_iteration!(irl, buffer, ŵ_new) @test ŵ_new != ŵ # updated? @test i + 1 == irl.i DataDrivenControl.reset!(irl)