From 7d13081d5fbfedb2720ef22262aadf886e7642a3 Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Thu, 30 Dec 2021 22:25:05 +0900 Subject: [PATCH 01/17] Modify update! (now it's evaluate_policy!) --- src/irl/linear_irl.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/irl/linear_irl.jl b/src/irl/linear_irl.jl index f1394cf..e314b72 100644 --- a/src/irl/linear_irl.jl +++ b/src/irl/linear_irl.jl @@ -7,6 +7,8 @@ See [1, "IRL Optimal Adaptive Control Using Value Iteration"]. # Notes - T: Data stack period - N: The maximum length of stacked data +- ϕs_prev: the vector of bases (evaluated) +- V̂: the vector of approximate values (evaluated) """ mutable struct LinearIRL <: AbstractIRL Q::AbstractMatrix @@ -30,6 +32,14 @@ mutable struct LinearIRL <: AbstractIRL @assert N >= N_min end @assert T > 0 && N > 0 + n1, n2 = size(Q) + @assert n1 == n2 + N_min = n1*(n1 + 1)/2 + if N == nothing + N = N_min + else + @assert N >= N_min + end R_inv = inv(R) i = i_init new(Q, R_inv, B, T, N, i, i_init) From 87cc34e7669c615248cb27779d067e517e5b8028 Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Thu, 30 Dec 2021 23:22:53 +0900 Subject: [PATCH 02/17] Fix bugs --- src/irl/linear_irl.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/irl/linear_irl.jl b/src/irl/linear_irl.jl index e314b72..2ae3d40 100644 --- a/src/irl/linear_irl.jl +++ b/src/irl/linear_irl.jl @@ -34,12 +34,13 @@ mutable struct LinearIRL <: AbstractIRL @assert T > 0 && N > 0 n1, n2 = size(Q) @assert n1 == n2 - N_min = n1*(n1 + 1)/2 + N_min = Int(n1*(n1 + 1)/2) if N == nothing N = N_min else @assert N >= N_min end + @assert T > 0 && N > 0 R_inv = inv(R) i = i_init new(Q, R_inv, B, T, N, i, i_init) From 50b556815c2dda6c6a9e111a28fe1d04da1f653c Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Thu, 30 Dec 2021 23:23:46 +0900 Subject: [PATCH 03/17] wip --- src/irl/linear_irl.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/irl/linear_irl.jl b/src/irl/linear_irl.jl index 2ae3d40..59bcd59 100644 --- a/src/irl/linear_irl.jl +++ b/src/irl/linear_irl.jl @@ -7,8 +7,6 @@ See [1, "IRL Optimal Adaptive Control Using Value Iteration"]. # Notes - T: Data stack period - N: The maximum length of stacked data -- ϕs_prev: the vector of bases (evaluated) -- V̂: the vector of approximate values (evaluated) """ mutable struct LinearIRL <: AbstractIRL Q::AbstractMatrix From b06fcf98c4a2399d7cf31a9f73950886c55915ca Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Fri, 31 Dec 2021 20:25:20 +0900 Subject: [PATCH 04/17] Add a test script for numerical simulation --- main/linear_irl.jl | 81 ++++++++++++++++++++++++++++++++++++++++ src/DataDrivenControl.jl | 2 + 2 files changed, 83 insertions(+) create mode 100644 main/linear_irl.jl diff --git a/main/linear_irl.jl b/main/linear_irl.jl new file mode 100644 index 0000000..43df145 --- /dev/null +++ b/main/linear_irl.jl @@ -0,0 +1,81 @@ +using DataDrivenControl +using FlightSims +using UnPack +using Plots +using LinearAlgebra +using Transducers +using LaTeXStrings + + +struct LinearSystem_ZOH_Gain + linear::LinearSystem + controller::LinearIRL +end + +function FlightSims.State(env::LinearSystem_ZOH_Gain) + @unpack linear = env + State(linear) +end + +function FlightSims.Dynamics!(env::LinearSystem_ZOH_Gain) + @unpack linear, controller = env + @Loggable function dynamics!(dx, x, w, t) + u = optimal_input(controller, x, w) + @onlylog param = w + @nested_log Dynamics!(linear)(dx, x, w, t; u=u) + end +end + +# See [1, Example 2. Continuous-Time Optimal Adaptive Control Using IRL] +# Refs +# [1] “Reinforcement Learning and Feedback Control: Using Natural Decision Methods to Design Optimal Adaptive Controllers,” IEEE Control Syst., vol. 32, no. 6, pp. 76–105, Dec. 2012, doi: 10.1109/MCS.2012.2214134. +function main() + n, m = 2, 1 + A = [ -10 1; + -0.002 -2] + B = [0 2]' + Q = Matrix(I, n, n) + R = Matrix(I, m, m) + cost = DataDrivenControl.QuadraticCost(Q, R) + linear = LinearSystem(A, B) + controller = DataDrivenControl.LinearIRL(Q, R, B) + env = LinearSystem_ZOH_Gain(linear, controller) + x0 = State(env)(rand(2)) + # TODO: add callback to update param `w` + # simulation + P = [0.0500 0.0039; + 0.0039 0.2085] # true solution + w = [P[1, 1], P[2, 1]+P[1, 2], P[2, 2]] # true solution + tf = 10.0 + simulator = Simulator(x0, Dynamics!(env), w; + tf=tf, + ) + Δt = controller.T + df = solve(simulator; savestep=Δt) + # plot + ts = df.time + xs = df.sol |> Map(datum -> datum.state) |> collect + us = df.sol |> Map(datum -> datum.input) |> collect + ws = df.sol |> Map(datum -> datum.param) |> collect + fig_x = plot(hcat(xs...)'; label=[L"x_{1}" L"x_{2}"]) + fig_u = plot(hcat(us...)'; label=L"u") + fig_w = plot(hcat(ws...)'; label=[L"w_{1}" L"w_{2}" L"w_{3}"]) + fig = plot(fig_x, fig_u, fig_w) + + + # # data buffer + # buffer = DataDrivenControl.DataBuffer() + # # TODO + # ts = 0:irl.T:1 + # for t in ts + # push!(buffer, irl, cost; t=t, x=x, u=û, w=ŵ) + # end + # @test length(buffer.data_array) == length(ts) + # ŵ_new = deepcopy(ŵ) + # i = deepcopy(irl.i) + # DataDrivenControl.evaluate_policy!(irl, buffer, ŵ_new) + # @test ŵ_new != ŵ # updated? + # @test i + 1 == irl.i + # DataDrivenControl.reset!(irl) + # @test irl.i == irl.i_init +end diff --git a/src/DataDrivenControl.jl b/src/DataDrivenControl.jl index f0c35d7..22a69df 100644 --- a/src/DataDrivenControl.jl +++ b/src/DataDrivenControl.jl @@ -6,6 +6,8 @@ using Transducers ## Costs export QuadraticInInputCost, QuadraticCost +export LinearIRL +export optimal_input include("utils/utils.jl") From bd8088c35233457751f431f5a9f07f8d6af8d46f Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Fri, 31 Dec 2021 20:30:13 +0900 Subject: [PATCH 05/17] Fix a bug for time-axis --- main/linear_irl.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/main/linear_irl.jl b/main/linear_irl.jl index 43df145..e2613ae 100644 --- a/main/linear_irl.jl +++ b/main/linear_irl.jl @@ -57,9 +57,9 @@ function main() xs = df.sol |> Map(datum -> datum.state) |> collect us = df.sol |> Map(datum -> datum.input) |> collect ws = df.sol |> Map(datum -> datum.param) |> collect - fig_x = plot(hcat(xs...)'; label=[L"x_{1}" L"x_{2}"]) - fig_u = plot(hcat(us...)'; label=L"u") - fig_w = plot(hcat(ws...)'; label=[L"w_{1}" L"w_{2}" L"w_{3}"]) + fig_x = plot(ts, hcat(xs...)'; label=[L"x_{1}" L"x_{2}"]) + fig_u = plot(ts, hcat(us...)'; label=L"u") + fig_w = plot(ts, hcat(ws...)'; label=[L"w_{1}" L"w_{2}" L"w_{3}"]) fig = plot(fig_x, fig_u, fig_w) From 634abf7d8967fd67ddb9e7381e6a7b2dda79c2df Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Fri, 31 Dec 2021 20:48:05 +0900 Subject: [PATCH 06/17] Add linear irl test --- main/linear_irl.jl | 58 +++++++++++++++++++++------------------- src/DataDrivenControl.jl | 4 ++- src/irl/linear_irl.jl | 2 +- 3 files changed, 35 insertions(+), 29 deletions(-) diff --git a/main/linear_irl.jl b/main/linear_irl.jl index e2613ae..620617c 100644 --- a/main/linear_irl.jl +++ b/main/linear_irl.jl @@ -5,6 +5,7 @@ using Plots using LinearAlgebra using Transducers using LaTeXStrings +using DiffEqCallbacks struct LinearSystem_ZOH_Gain @@ -38,44 +39,47 @@ function main() R = Matrix(I, m, m) cost = DataDrivenControl.QuadraticCost(Q, R) linear = LinearSystem(A, B) - controller = DataDrivenControl.LinearIRL(Q, R, B) + controller = DataDrivenControl.LinearIRL(Q, R, B; T=0.04) env = LinearSystem_ZOH_Gain(linear, controller) x0 = State(env)(rand(2)) # TODO: add callback to update param `w` # simulation - P = [0.0500 0.0039; - 0.0039 0.2085] # true solution - w = [P[1, 1], P[2, 1]+P[1, 2], P[2, 2]] # true solution + P_true = [0.0500 0.0039; + 0.0039 0.2085] # true solution + w_true = [P_true[1, 1], P_true[2, 1]+P_true[1, 2], P_true[2, 2]] + scale = 0.1 + w0 = w_true + scale*randn(3) # perturbed initial guess tf = 10.0 - simulator = Simulator(x0, Dynamics!(env), w; + simulator = Simulator(x0, Dynamics!(env), w0; tf=tf, ) - Δt = controller.T - df = solve(simulator; savestep=Δt) + Δt = 0.01 + # TODO: data buffer is a duplicate of saving callback (see `df`) + # data buffer + buffer = DataDrivenControl.DataBuffer() + function update!(integrator) + t = integrator.t + x = integrator.u # convention of DifferentialEquations.jl + w = integrator.p # convention of DifferentialEquations.jl + u = optimal_input(controller, x, w) # TODO: not to be duplicate of control input in dynamics for stable coding + push!(buffer, controller, cost; + t=t, x=copy(x), u=copy(u), w=copy(w)) + evaluate_policy!(controller, buffer, w) + end + cb_irl = PeriodicCallback(update!, controller.T; initial_affect=true) # stack initial data + df = solve(simulator; + callback=cb_irl, + savestep=Δt, + ) # plot ts = df.time xs = df.sol |> Map(datum -> datum.state) |> collect us = df.sol |> Map(datum -> datum.input) |> collect ws = df.sol |> Map(datum -> datum.param) |> collect - fig_x = plot(ts, hcat(xs...)'; label=[L"x_{1}" L"x_{2}"]) - fig_u = plot(ts, hcat(us...)'; label=L"u") - fig_w = plot(ts, hcat(ws...)'; label=[L"w_{1}" L"w_{2}" L"w_{3}"]) + fig_x = plot(ts, hcat(xs...)'; label=[L"x_{1}" L"x_{2}"], legend=:outerbottomright) + fig_u = plot(ts, hcat(us...)'; label=L"u", legend=:outerbottomright) + fig_w = plot(ts, hcat(ws...)'; label=[L"w_{1}" L"w_{2}" L"w_{3}"], legend=:outerbottomright) + ws_true = ts |> Map(t -> w_true) |> collect + plot!(fig_w, ts, hcat(ws_true...)'; color=:black, label=nothing) fig = plot(fig_x, fig_u, fig_w) - - - # # data buffer - # buffer = DataDrivenControl.DataBuffer() - # # TODO - # ts = 0:irl.T:1 - # for t in ts - # push!(buffer, irl, cost; t=t, x=x, u=û, w=ŵ) - # end - # @test length(buffer.data_array) == length(ts) - # ŵ_new = deepcopy(ŵ) - # i = deepcopy(irl.i) - # DataDrivenControl.evaluate_policy!(irl, buffer, ŵ_new) - # @test ŵ_new != ŵ # updated? - # @test i + 1 == irl.i - # DataDrivenControl.reset!(irl) - # @test irl.i == irl.i_init end diff --git a/src/DataDrivenControl.jl b/src/DataDrivenControl.jl index 22a69df..b829eeb 100644 --- a/src/DataDrivenControl.jl +++ b/src/DataDrivenControl.jl @@ -6,7 +6,9 @@ using Transducers ## Costs export QuadraticInInputCost, QuadraticCost -export LinearIRL +## IRL +# Linear IRL +export LinearIRL, evaluate_policy! export optimal_input diff --git a/src/irl/linear_irl.jl b/src/irl/linear_irl.jl index 59bcd59..705379c 100644 --- a/src/irl/linear_irl.jl +++ b/src/irl/linear_irl.jl @@ -61,7 +61,7 @@ function evaluate_policy!(irl::LinearIRL, buffer::DataBuffer, w, @unpack i, N = irl @unpack data_array = buffer data_filtered = filter(x -> x.i == i, data_array) # data from the current policy - if length(data_filtered) >= N + if length(data_filtered) >= N + 1 data_sorted = sort(data_filtered, by=x -> x.t) # sort by time index ϕs_prev = data_sorted[end-N:end-1] |> Map(datum -> datum.ϕ) |> collect V̂s = data_sorted[end-(N-1):end] |> Map(datum -> datum.V̂) |> collect From 6ffed39c991e874e2acd6f861b0e99c70b7766de Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Sun, 2 Jan 2022 14:27:15 +0900 Subject: [PATCH 07/17] evaluate_policy! to value_iteration! --- main/linear_irl.jl | 2 +- src/DataDrivenControl.jl | 2 +- src/irl/linear_irl.jl | 22 ++++++++++++++-------- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/main/linear_irl.jl b/main/linear_irl.jl index 620617c..ad1525d 100644 --- a/main/linear_irl.jl +++ b/main/linear_irl.jl @@ -64,7 +64,7 @@ function main() u = optimal_input(controller, x, w) # TODO: not to be duplicate of control input in dynamics for stable coding push!(buffer, controller, cost; t=t, x=copy(x), u=copy(u), w=copy(w)) - evaluate_policy!(controller, buffer, w) + value_iteration!(controller, buffer, w) end cb_irl = PeriodicCallback(update!, controller.T; initial_affect=true) # stack initial data df = solve(simulator; diff --git a/src/DataDrivenControl.jl b/src/DataDrivenControl.jl index b829eeb..e5c8878 100644 --- a/src/DataDrivenControl.jl +++ b/src/DataDrivenControl.jl @@ -8,7 +8,7 @@ using Transducers export QuadraticInInputCost, QuadraticCost ## IRL # Linear IRL -export LinearIRL, evaluate_policy! +export LinearIRL, value_iteration! export optimal_input diff --git a/src/irl/linear_irl.jl b/src/irl/linear_irl.jl index 705379c..d590f44 100644 --- a/src/irl/linear_irl.jl +++ b/src/irl/linear_irl.jl @@ -56,15 +56,19 @@ w: critic parameter (vectorised) - ϕs_prev: the vector of bases (evaluated) - V̂: the vector of approximate values (evaluated) """ -function evaluate_policy!(irl::LinearIRL, buffer::DataBuffer, w, - ) +function value_iteration!(irl::LinearIRL, buffer::DataBuffer, w) @unpack i, N = irl @unpack data_array = buffer data_filtered = filter(x -> x.i == i, data_array) # data from the current policy if length(data_filtered) >= N + 1 data_sorted = sort(data_filtered, by=x -> x.t) # sort by time index ϕs_prev = data_sorted[end-N:end-1] |> Map(datum -> datum.ϕ) |> collect - V̂s = data_sorted[end-(N-1):end] |> Map(datum -> datum.V̂) |> collect + # V̂s = data_sorted[end-(N-1):end] |> Map(datum -> datum.V̂) |> collect + xs = data_sorted[end-(N-1):end] |> Map(datum -> datum.x) |> collect + ∫rs = data_sorted[end-(N-1):end] |> Map(datum -> datum.∫r) |> collect + P = convert_to_matrix(w) + V̂s_with_prev_P = xs |> Map(x -> value(irl, P, x)) |> collect + V̂s = ∫rs .+ V̂s_with_prev_P irl.i += 1 # update iteration number # update the critic parameter w .= ( hcat(V̂s...) * pinv(hcat(ϕs_prev...)) )'[:] # to reduce the computation time; [:] for vectorisation @@ -105,11 +109,11 @@ w: critic parameter (vectorised) function Base.push!(buffer::DataBuffer, irl::LinearIRL, cost::AbstractCost; t, x, u, w, ) - P = convert_to_matrix(w) + # P = convert_to_matrix(w) @unpack data_array = buffer @unpack i = irl data_sorted = sort(data_array, by = x -> x.t) # sorted by t - V̂_with_prev_P = value(irl, P, x) + # V̂_with_prev_P = value(irl, P, x) # prev data if length(data_sorted) != 0 t_prev = data_sorted[end].t @@ -120,9 +124,10 @@ function Base.push!(buffer::DataBuffer, irl::LinearIRL, cost::AbstractCost; r = cost(x, u) r_prev = cost(x_prev, u_prev) ∫r = 0.5 * (r + r_prev) * Δt # trapezoidal - V̂ = ∫r + V̂_with_prev_P + # V̂ = ∫r + V̂_with_prev_P else - V̂ = missing + # V̂ = missing + ∫r = missing end ϕ = convert_quadratic_to_linear_basis(x) # x'Px = w'ϕ(x) datum = (; @@ -131,7 +136,8 @@ function Base.push!(buffer::DataBuffer, irl::LinearIRL, cost::AbstractCost; u=u, w=w, # logging ϕ=ϕ, - V̂=V̂, + ∫r=∫r, + # V̂=V̂, i=i, # iteration number ) push!(buffer.data_array, datum) From 1ae62ab14cfb793a5769ed931d23d8ea254b5e17 Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Sun, 2 Jan 2022 14:31:07 +0900 Subject: [PATCH 08/17] Add `policy_iteration!` --- main/linear_irl.jl | 3 ++- src/DataDrivenControl.jl | 2 +- src/irl/linear_irl.jl | 22 +++++++++++++++++++++- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/main/linear_irl.jl b/main/linear_irl.jl index ad1525d..8c73495 100644 --- a/main/linear_irl.jl +++ b/main/linear_irl.jl @@ -64,7 +64,8 @@ function main() u = optimal_input(controller, x, w) # TODO: not to be duplicate of control input in dynamics for stable coding push!(buffer, controller, cost; t=t, x=copy(x), u=copy(u), w=copy(w)) - value_iteration!(controller, buffer, w) + # value_iteration!(controller, buffer, w) + policy_iteration!(controller, buffer, w) end cb_irl = PeriodicCallback(update!, controller.T; initial_affect=true) # stack initial data df = solve(simulator; diff --git a/src/DataDrivenControl.jl b/src/DataDrivenControl.jl index e5c8878..b88f2c3 100644 --- a/src/DataDrivenControl.jl +++ b/src/DataDrivenControl.jl @@ -8,7 +8,7 @@ using Transducers export QuadraticInInputCost, QuadraticCost ## IRL # Linear IRL -export LinearIRL, value_iteration! +export LinearIRL, value_iteration!, policy_iteration! export optimal_input diff --git a/src/irl/linear_irl.jl b/src/irl/linear_irl.jl index d590f44..a7115d2 100644 --- a/src/irl/linear_irl.jl +++ b/src/irl/linear_irl.jl @@ -69,13 +69,33 @@ function value_iteration!(irl::LinearIRL, buffer::DataBuffer, w) P = convert_to_matrix(w) V̂s_with_prev_P = xs |> Map(x -> value(irl, P, x)) |> collect V̂s = ∫rs .+ V̂s_with_prev_P - irl.i += 1 # update iteration number # update the critic parameter w .= ( hcat(V̂s...) * pinv(hcat(ϕs_prev...)) )'[:] # to reduce the computation time; [:] for vectorisation # w .= pinv(hcat(ϕs_prev...)') * hcat(V̂s...)' # least square sense + update_index!(irl) end end +function policy_iteration!(irl::LinearIRL, buffer::DataBuffer, w) + @unpack i, N = irl + @unpack data_array = buffer + data_filtered = filter(x -> x.i == i, data_array) # data from the current policy + if length(data_filtered) >= N + 1 + data_sorted = sort(data_filtered, by=x -> x.t) # sort by time index + ϕs_prev_and_present = data_sorted[end-N:end] |> Map(datum -> datum.ϕ) |> collect + Δϕs = diff(ϕs_prev_and_present) + ∫rs = data_sorted[end-(N-1):end] |> Map(datum -> datum.∫r) |> collect + # update the critic parameter + w .= ( hcat(∫rs...) * pinv(hcat(-Δϕs...)) )'[:] # to reduce the computation time; [:] for vectorisation + # w .= pinv(hcat(-Δϕs...)') * hcat(∫rs...)' # least square sense + update_index!(irl) + end +end + +function update_index!(irl::LinearIRL) + irl.i += 1 +end + """ Policy improvement [1, Eq. 96]. """ From 1466f91c589f012df3f544a1f2268c30d418a449 Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Sun, 2 Jan 2022 14:36:49 +0900 Subject: [PATCH 09/17] Add docs --- src/irl/linear_irl.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/irl/linear_irl.jl b/src/irl/linear_irl.jl index a7115d2..fd32f09 100644 --- a/src/irl/linear_irl.jl +++ b/src/irl/linear_irl.jl @@ -1,5 +1,5 @@ """ -See [1, "IRL Optimal Adaptive Control Using Value Iteration"]. +See [1, "Online Implementation of IRL: A Hybrid Optimal Adaptive Controller"]. # Refs [1] “Reinforcement Learning and Feedback Control: Using Natural Decision Methods to Design Optimal Adaptive Controllers,” IEEE Control Syst., vol. 32, no. 6, pp. 76–105, Dec. 2012, doi: 10.1109/MCS.2012.2214134. @@ -76,6 +76,13 @@ function value_iteration!(irl::LinearIRL, buffer::DataBuffer, w) end end +""" +Policy iteration [1, Eq. 98]; updated in least-square sense +# Notes +w: critic parameter (vectorised) +- Δϕs: the vector of (ϕ - ϕ_prev) (evaluated) +- ∫rs: the vector of integral running cost by numerical integration (evaluated) +""" function policy_iteration!(irl::LinearIRL, buffer::DataBuffer, w) @unpack i, N = irl @unpack data_array = buffer From d7a06a81a287a2b3760dd3e610e244cadc5e205e Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Mon, 3 Jan 2022 10:54:25 +0900 Subject: [PATCH 10/17] Fix test --- test/irl/irl.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/irl/irl.jl b/test/irl/irl.jl index b896bae..59a8a66 100644 --- a/test/irl/irl.jl +++ b/test/irl/irl.jl @@ -30,7 +30,8 @@ using LinearAlgebra @test length(buffer.data_array) == length(ts) ŵ_new = deepcopy(ŵ) i = deepcopy(irl.i) - DataDrivenControl.evaluate_policy!(irl, buffer, ŵ_new) + DataDrivenControl.policy_iteration!(irl, buffer, ŵ_new) + DataDrivenControl.value_iteration!(irl, buffer, ŵ_new) @test ŵ_new != ŵ # updated? @test i + 1 == irl.i DataDrivenControl.reset!(irl) From 909d7128063f30165957c9b8582ff6f9ff2ff8e2 Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Mon, 3 Jan 2022 18:45:40 +0900 Subject: [PATCH 11/17] Add stop cond --- main/linear_irl.jl | 11 ++++-- src/DataDrivenControl.jl | 3 ++ src/irl/linear_irl.jl | 36 +++++++++++++++----- src/utils/stop_conditions/stop_conditions.jl | 20 +++++++++++ src/utils/utils.jl | 1 + 5 files changed, 59 insertions(+), 12 deletions(-) create mode 100644 src/utils/stop_conditions/stop_conditions.jl diff --git a/main/linear_irl.jl b/main/linear_irl.jl index 8c73495..9fadeb0 100644 --- a/main/linear_irl.jl +++ b/main/linear_irl.jl @@ -23,6 +23,7 @@ function FlightSims.Dynamics!(env::LinearSystem_ZOH_Gain) @Loggable function dynamics!(dx, x, w, t) u = optimal_input(controller, x, w) @onlylog param = w + @onlylog i = controller.i @nested_log Dynamics!(linear)(dx, x, w, t; u=u) end end @@ -64,8 +65,10 @@ function main() u = optimal_input(controller, x, w) # TODO: not to be duplicate of control input in dynamics for stable coding push!(buffer, controller, cost; t=t, x=copy(x), u=copy(u), w=copy(w)) - # value_iteration!(controller, buffer, w) - policy_iteration!(controller, buffer, w) + eps = 1e-1 + sc = DistanceStopCondition(eps) + value_iteration!(controller, buffer, w; sc=sc) + # policy_iteration!(controller, buffer, w; sc=sc) end cb_irl = PeriodicCallback(update!, controller.T; initial_affect=true) # stack initial data df = solve(simulator; @@ -77,10 +80,12 @@ function main() xs = df.sol |> Map(datum -> datum.state) |> collect us = df.sol |> Map(datum -> datum.input) |> collect ws = df.sol |> Map(datum -> datum.param) |> collect + is = df.sol |> Map(datum -> datum.i) |> collect fig_x = plot(ts, hcat(xs...)'; label=[L"x_{1}" L"x_{2}"], legend=:outerbottomright) fig_u = plot(ts, hcat(us...)'; label=L"u", legend=:outerbottomright) fig_w = plot(ts, hcat(ws...)'; label=[L"w_{1}" L"w_{2}" L"w_{3}"], legend=:outerbottomright) + fig_iter = plot(ts, hcat(is...)'; label="iter") ws_true = ts |> Map(t -> w_true) |> collect plot!(fig_w, ts, hcat(ws_true...)'; color=:black, label=nothing) - fig = plot(fig_x, fig_u, fig_w) + fig = plot(fig_x, fig_u, fig_w, fig_iter) end diff --git a/src/DataDrivenControl.jl b/src/DataDrivenControl.jl index b88f2c3..3e031dc 100644 --- a/src/DataDrivenControl.jl +++ b/src/DataDrivenControl.jl @@ -11,6 +11,9 @@ export QuadraticInInputCost, QuadraticCost export LinearIRL, value_iteration!, policy_iteration! export optimal_input +## Stop conditions +export DistanceStopCondition + include("utils/utils.jl") include("irl/irl.jl") diff --git a/src/irl/linear_irl.jl b/src/irl/linear_irl.jl index fd32f09..c858378 100644 --- a/src/irl/linear_irl.jl +++ b/src/irl/linear_irl.jl @@ -16,6 +16,7 @@ mutable struct LinearIRL <: AbstractIRL N::Int i::Int i_init::Int + stopped::Bool function LinearIRL(Q::AbstractMatrix, R::AbstractMatrix, B::AbstractMatrix; T=0.04, N=nothing, @@ -41,7 +42,8 @@ mutable struct LinearIRL <: AbstractIRL @assert T > 0 && N > 0 R_inv = inv(R) i = i_init - new(Q, R_inv, B, T, N, i, i_init) + stopped = false + new(Q, R_inv, B, T, N, i, i_init, stopped) end end @@ -56,7 +58,9 @@ w: critic parameter (vectorised) - ϕs_prev: the vector of bases (evaluated) - V̂: the vector of approximate values (evaluated) """ -function value_iteration!(irl::LinearIRL, buffer::DataBuffer, w) +function value_iteration!(irl::LinearIRL, buffer::DataBuffer, w; + sc::AbstractStopCondition=DistanceStopCondition(eps), + ) @unpack i, N = irl @unpack data_array = buffer data_filtered = filter(x -> x.i == i, data_array) # data from the current policy @@ -70,9 +74,15 @@ function value_iteration!(irl::LinearIRL, buffer::DataBuffer, w) V̂s_with_prev_P = xs |> Map(x -> value(irl, P, x)) |> collect V̂s = ∫rs .+ V̂s_with_prev_P # update the critic parameter - w .= ( hcat(V̂s...) * pinv(hcat(ϕs_prev...)) )'[:] # to reduce the computation time; [:] for vectorisation - # w .= pinv(hcat(ϕs_prev...)') * hcat(V̂s...)' # least square sense - update_index!(irl) + w_new = ( hcat(V̂s...) * pinv(hcat(ϕs_prev...)) )'[:] # to reduce the computation time; [:] for vectorisation + if !irl.stopped + if is_stopped(sc, w, w_new) + irl.stopped = true + else + w .= w_new + update_index!(irl) + end + end end end @@ -83,7 +93,9 @@ w: critic parameter (vectorised) - Δϕs: the vector of (ϕ - ϕ_prev) (evaluated) - ∫rs: the vector of integral running cost by numerical integration (evaluated) """ -function policy_iteration!(irl::LinearIRL, buffer::DataBuffer, w) +function policy_iteration!(irl::LinearIRL, buffer::DataBuffer, w; + sc::AbstractStopCondition=DistanceStopCondition(eps), + ) @unpack i, N = irl @unpack data_array = buffer data_filtered = filter(x -> x.i == i, data_array) # data from the current policy @@ -93,9 +105,15 @@ function policy_iteration!(irl::LinearIRL, buffer::DataBuffer, w) Δϕs = diff(ϕs_prev_and_present) ∫rs = data_sorted[end-(N-1):end] |> Map(datum -> datum.∫r) |> collect # update the critic parameter - w .= ( hcat(∫rs...) * pinv(hcat(-Δϕs...)) )'[:] # to reduce the computation time; [:] for vectorisation - # w .= pinv(hcat(-Δϕs...)') * hcat(∫rs...)' # least square sense - update_index!(irl) + w_new = ( hcat(∫rs...) * pinv(hcat(-Δϕs...)) )'[:] # to reduce the computation time; [:] for vectorisation + if !irl.stopped + if is_stopped(sc, w, w_new) + irl.stopped = true + else + w .= w_new + update_index!(irl) + end + end end end diff --git a/src/utils/stop_conditions/stop_conditions.jl b/src/utils/stop_conditions/stop_conditions.jl new file mode 100644 index 0000000..a3cbd0f --- /dev/null +++ b/src/utils/stop_conditions/stop_conditions.jl @@ -0,0 +1,20 @@ +abstract type AbstractStopCondition end + +function is_stopped(sc::AbstractStopCondition, args...; kwrags...) + error("Defined this method for type: $(typeof(sc))") +end + + +struct DistanceStopCondition <: AbstractStopCondition + eps::Real + p::Real + function DistanceStopCondition(eps=1e-3, p=2) + @assert p > 1 + new(p, eps) + end +end + +function is_stopped(sc::DistanceStopCondition, w, w_new) + @unpack eps, p = sc + norm(w_new - w, p) < eps +end diff --git a/src/utils/utils.jl b/src/utils/utils.jl index 07f09d2..cd1eea6 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -1,3 +1,4 @@ include("convert.jl") include("data_buffer.jl") include("cost/cost.jl") +include("stop_conditions/stop_conditions.jl") From 056edd3e2b8866ae559695f9084e94a7dc35b8e7 Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Mon, 3 Jan 2022 18:46:58 +0900 Subject: [PATCH 12/17] Change tf for visualisation --- main/linear_irl.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/main/linear_irl.jl b/main/linear_irl.jl index 9fadeb0..ea1208c 100644 --- a/main/linear_irl.jl +++ b/main/linear_irl.jl @@ -50,7 +50,7 @@ function main() w_true = [P_true[1, 1], P_true[2, 1]+P_true[1, 2], P_true[2, 2]] scale = 0.1 w0 = w_true + scale*randn(3) # perturbed initial guess - tf = 10.0 + tf = 5.0 simulator = Simulator(x0, Dynamics!(env), w0; tf=tf, ) @@ -67,8 +67,8 @@ function main() t=t, x=copy(x), u=copy(u), w=copy(w)) eps = 1e-1 sc = DistanceStopCondition(eps) - value_iteration!(controller, buffer, w; sc=sc) - # policy_iteration!(controller, buffer, w; sc=sc) + # value_iteration!(controller, buffer, w; sc=sc) + policy_iteration!(controller, buffer, w; sc=sc) end cb_irl = PeriodicCallback(update!, controller.T; initial_affect=true) # stack initial data df = solve(simulator; From 2bc7b22f16c236d6ce1ac659225e3cb5aff39e1d Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Mon, 3 Jan 2022 18:54:44 +0900 Subject: [PATCH 13/17] wip --- src/utils/stop_conditions/stop_conditions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/stop_conditions/stop_conditions.jl b/src/utils/stop_conditions/stop_conditions.jl index a3cbd0f..5f754ca 100644 --- a/src/utils/stop_conditions/stop_conditions.jl +++ b/src/utils/stop_conditions/stop_conditions.jl @@ -8,7 +8,7 @@ end struct DistanceStopCondition <: AbstractStopCondition eps::Real p::Real - function DistanceStopCondition(eps=1e-3, p=2) + function DistanceStopCondition(eps=1e-1, p=2) @assert p > 1 new(p, eps) end From 966bced06ef38c91244532a6dbf737d99802365b Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Mon, 3 Jan 2022 19:10:55 +0900 Subject: [PATCH 14/17] Add README --- README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/README.md b/README.md index 6152150..20b06e1 100644 --- a/README.md +++ b/README.md @@ -1 +1,14 @@ # DataDrivenControl + +## Algorithms +### Integral Reinforcement Learning (IRL) +- `LinearIRL`: IRL algorithm for affine dynamics and quadratic cost (in control input) + - Update methods + 1. `policy_iteration!` [1, Eqs. (98), (96)], [2] + 2. `value_iteration!` [1, Eqs. (99), (96)] + + +## References +[1] F. L. Lewis, D. Vrabie, and K. G. Vamvoudakis, “Reinforcement Learning and Feedback Control: Using Natural Decision Methods to Design Optimal Adaptive Controllers,” IEEE Control Syst., vol. 32, no. 6, pp. 76–105, Dec. 2012, doi: 10.1109/MCS.2012.2214134. +[2] D. Vrabie, O. Pastravanu, M. Abu-Khalaf, and F. L. Lewis, “Adaptive Optimal Control for Continuous-Time Linear Systems Based on Policy Iteration,” Automatica, vol. 45, no. 2, pp. 477–484, Feb. 2009, doi: 10.1016/j.automatica.2008.08.017. + From 2cbbdb6f7d9815d025116605eab587ebdef1dbfc Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Mon, 3 Jan 2022 19:11:25 +0900 Subject: [PATCH 15/17] wip --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 20b06e1..b34f4ab 100644 --- a/README.md +++ b/README.md @@ -10,5 +10,5 @@ ## References [1] F. L. Lewis, D. Vrabie, and K. G. Vamvoudakis, “Reinforcement Learning and Feedback Control: Using Natural Decision Methods to Design Optimal Adaptive Controllers,” IEEE Control Syst., vol. 32, no. 6, pp. 76–105, Dec. 2012, doi: 10.1109/MCS.2012.2214134. -[2] D. Vrabie, O. Pastravanu, M. Abu-Khalaf, and F. L. Lewis, “Adaptive Optimal Control for Continuous-Time Linear Systems Based on Policy Iteration,” Automatica, vol. 45, no. 2, pp. 477–484, Feb. 2009, doi: 10.1016/j.automatica.2008.08.017. +[2] D. Vrabie, O. Pastravanu, M. Abu-Khalaf, and F. L. Lewis, “Adaptive Optimal Control for Continuous-Time Linear Systems Based on Policy Iteration,” Automatica, vol. 45, no. 2, pp. 477–484, Feb. 2009, doi: 10.1016/j.automatica.2008.08.017. From b0fcf31f7bb68ea38aa24320231a03b9f9a55dff Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Mon, 3 Jan 2022 19:21:32 +0900 Subject: [PATCH 16/17] Fix a bug --- src/irl/linear_irl.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/irl/linear_irl.jl b/src/irl/linear_irl.jl index c858378..8b7eb63 100644 --- a/src/irl/linear_irl.jl +++ b/src/irl/linear_irl.jl @@ -59,7 +59,7 @@ w: critic parameter (vectorised) - V̂: the vector of approximate values (evaluated) """ function value_iteration!(irl::LinearIRL, buffer::DataBuffer, w; - sc::AbstractStopCondition=DistanceStopCondition(eps), + sc::AbstractStopCondition=DistanceStopCondition(), ) @unpack i, N = irl @unpack data_array = buffer @@ -94,7 +94,7 @@ w: critic parameter (vectorised) - ∫rs: the vector of integral running cost by numerical integration (evaluated) """ function policy_iteration!(irl::LinearIRL, buffer::DataBuffer, w; - sc::AbstractStopCondition=DistanceStopCondition(eps), + sc::AbstractStopCondition=DistanceStopCondition(), ) @unpack i, N = irl @unpack data_array = buffer From 2d0e49a89c344bac08e13ca4073f922ecbb52678 Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Thu, 6 Jan 2022 00:06:55 +0900 Subject: [PATCH 17/17] Rename stopped as terminated --- src/irl/linear_irl.jl | 18 +++++++++--------- src/utils/stop_conditions/stop_conditions.jl | 4 ++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/irl/linear_irl.jl b/src/irl/linear_irl.jl index 8b7eb63..3dde20c 100644 --- a/src/irl/linear_irl.jl +++ b/src/irl/linear_irl.jl @@ -16,7 +16,7 @@ mutable struct LinearIRL <: AbstractIRL N::Int i::Int i_init::Int - stopped::Bool + terminated::Bool function LinearIRL(Q::AbstractMatrix, R::AbstractMatrix, B::AbstractMatrix; T=0.04, N=nothing, @@ -42,8 +42,8 @@ mutable struct LinearIRL <: AbstractIRL @assert T > 0 && N > 0 R_inv = inv(R) i = i_init - stopped = false - new(Q, R_inv, B, T, N, i, i_init, stopped) + terminated = false + new(Q, R_inv, B, T, N, i, i_init, terminated) end end @@ -75,9 +75,9 @@ function value_iteration!(irl::LinearIRL, buffer::DataBuffer, w; V̂s = ∫rs .+ V̂s_with_prev_P # update the critic parameter w_new = ( hcat(V̂s...) * pinv(hcat(ϕs_prev...)) )'[:] # to reduce the computation time; [:] for vectorisation - if !irl.stopped - if is_stopped(sc, w, w_new) - irl.stopped = true + if !irl.terminated + if is_terminated(sc, w, w_new) + irl.terminated = true else w .= w_new update_index!(irl) @@ -106,9 +106,9 @@ function policy_iteration!(irl::LinearIRL, buffer::DataBuffer, w; ∫rs = data_sorted[end-(N-1):end] |> Map(datum -> datum.∫r) |> collect # update the critic parameter w_new = ( hcat(∫rs...) * pinv(hcat(-Δϕs...)) )'[:] # to reduce the computation time; [:] for vectorisation - if !irl.stopped - if is_stopped(sc, w, w_new) - irl.stopped = true + if !irl.terminated + if is_terminated(sc, w, w_new) + irl.terminated = true else w .= w_new update_index!(irl) diff --git a/src/utils/stop_conditions/stop_conditions.jl b/src/utils/stop_conditions/stop_conditions.jl index 5f754ca..a3077f9 100644 --- a/src/utils/stop_conditions/stop_conditions.jl +++ b/src/utils/stop_conditions/stop_conditions.jl @@ -1,6 +1,6 @@ abstract type AbstractStopCondition end -function is_stopped(sc::AbstractStopCondition, args...; kwrags...) +function is_terminated(sc::AbstractStopCondition, args...; kwrags...) error("Defined this method for type: $(typeof(sc))") end @@ -14,7 +14,7 @@ struct DistanceStopCondition <: AbstractStopCondition end end -function is_stopped(sc::DistanceStopCondition, w, w_new) +function is_terminated(sc::DistanceStopCondition, w, w_new) @unpack eps, p = sc norm(w_new - w, p) < eps end