jlesc17

Solving optimal control problems on GPU with Julia

Jean-Baptiste Caillau, Olivier Cots, Joseph Gergaud, Pierre Martinon, Sophia Sed

affiliations

What it's about

  • Nonlinear optimal control of ODEs:

\[g(x(t_0),x(t_f)) + \int_{t_0}^{t_f} f^0(x(t), u(t))\, \mathrm{d}t \to \min\]

subject to

\[\dot{x}(t) = f(x(t), u(t)),\quad t \in [t_0, t_f]\]

plus boundary, control and state constraints

  • Our core interests: numerical & geometrical methods in control, applications
  • Why Julia: fast (+ JIT), strongly typed, high-level (AD, macros), fast optimisation and ODE solvers available, rapidly growing community
juliacon2025

Discretise then solve strategy (aka direct methods)

  • Discretising an OCP into an NLP: $h_i := t_{i+1}-t_i$,

\[g(X_0,X_N) + \sum_{i=0}^{N} h_i f^0(X_i,U_i) \to \min\]

subject to

\[X_{i+1} - X_i - h_i f(X_i, U_i) = 0,\quad i = 0,\dots,N-1\]

plus other constraints on $X := (X_i)_{i=0,N}$ and $U := (U_i)_{i=0,N}$ such as boundary and path (state and / or control) constraints :

\[b(t_0, X_0, t_N, X_N) = 0\]

\[g(X_i, U_i) = 0,\quad i = 0,\dots,N\]

Simple example, generated code
begin
    #= /data/caillau/CTParser.jl/src/onepass.jl:1003 =#
    function (; scheme = :trapezoidal, grid_size = 200, backend = nothing, init = (0.1, 0.1, 0.1), base_type = Float64)
        #= /data/caillau/CTParser.jl/src/onepass.jl:1003 =#
        #= /data/caillau/CTParser.jl/src/onepass.jl:1004 =#
        LineNumberNode(0, "box constraints: variable")
        #= /data/caillau/CTParser.jl/src/onepass.jl:1005 =#
        begin
            LineNumberNode(0, "box constraints: state")
            begin
                var"##235" = -Inf * ones(3)
                #= /data/caillau/CTParser.jl/src/onepass.jl:461 =#
                var"##236" = Inf * ones(3)
            end
        end
        #= /data/caillau/CTParser.jl/src/onepass.jl:1006 =#
        begin
            LineNumberNode(0, "box constraints: control")
            begin
                var"##237" = -Inf * ones(1)
                #= /data/caillau/CTParser.jl/src/onepass.jl:512 =#
                var"##238" = Inf * ones(1)
            end
        end
        #= /data/caillau/CTParser.jl/src/onepass.jl:1007 =#
        var"##230" = ExaModels.ExaCore(base_type; backend = backend)
        #= /data/caillau/CTParser.jl/src/onepass.jl:1008 =#
        begin
            #= /data/caillau/CTParser.jl/test/test_onepass_exa.jl:23 =#
            var"##232" = begin
                    #= /data/caillau/CTParser.jl/src/onepass.jl:111 =#
                    local ex
                    #= /data/caillau/CTParser.jl/src/onepass.jl:112 =#
                    try
                        #= /data/caillau/CTParser.jl/src/onepass.jl:113 =#
                        (1 - 0) / grid_size
                    catch ex
                        #= /data/caillau/CTParser.jl/src/onepass.jl:115 =#
                        println("Line ", 1, ": ", "(t ∈ [0, 1], time)")
                        #= /data/caillau/CTParser.jl/src/onepass.jl:116 =#
                        throw(ex)
                    end
                end
            #= /data/caillau/CTParser.jl/test/test_onepass_exa.jl:24 =#
            x = begin
                    #= /data/caillau/CTParser.jl/src/onepass.jl:111 =#
                    local ex
                    #= /data/caillau/CTParser.jl/src/onepass.jl:112 =#
                    try
                        #= /data/caillau/CTParser.jl/src/onepass.jl:113 =#
                        ExaModels.variable(var"##230", 3, 0:grid_size; lvar = [var"##235"[i] for (i, j) = Base.product(1:3, 0:grid_size)], uvar = [var"##236"[i] for (i, j) = Base.product(1:3, 0:grid_size)], start = init[2])
                    catch ex
                        #= /data/caillau/CTParser.jl/src/onepass.jl:115 =#
                        println("Line ", 2, ": ", "(x ∈ R ^ 3, state)")
                        #= /data/caillau/CTParser.jl/src/onepass.jl:116 =#
                        throw(ex)
                    end
                end
            #= /data/caillau/CTParser.jl/test/test_onepass_exa.jl:25 =#
            var"u##239" = begin
                    #= /data/caillau/CTParser.jl/src/onepass.jl:111 =#
                    local ex
                    #= /data/caillau/CTParser.jl/src/onepass.jl:112 =#
                    try
                        #= /data/caillau/CTParser.jl/src/onepass.jl:113 =#
                        ExaModels.variable(var"##230", 1, 0:grid_size; lvar = [var"##237"[i] for (i, j) = Base.product(1:1, 0:grid_size)], uvar = [var"##238"[i] for (i, j) = Base.product(1:1, 0:grid_size)], start = init[3])
                    catch ex
                        #= /data/caillau/CTParser.jl/src/onepass.jl:115 =#
                        println("Line ", 3, ": ", "(u ∈ R, control)")
                        #= /data/caillau/CTParser.jl/src/onepass.jl:116 =#
                        throw(ex)
                    end
                end
            #= /data/caillau/CTParser.jl/test/test_onepass_exa.jl:26 =#
            begin
                #= /data/caillau/CTParser.jl/src/onepass.jl:111 =#
                local ex
                #= /data/caillau/CTParser.jl/src/onepass.jl:112 =#
                try
                    #= /data/caillau/CTParser.jl/src/onepass.jl:113 =#
                    ExaModels.constraint(var"##230", (x[i, 0] for i = 1:3); lcon = [-1, 0, 0], ucon = [-1, 0, 0])
                catch ex
                    #= /data/caillau/CTParser.jl/src/onepass.jl:115 =#
                    println("Line ", 4, ": ", "x(0) == [-1, 0, 0]")
                    #= /data/caillau/CTParser.jl/src/onepass.jl:116 =#
                    throw(ex)
                end
            end
            #= /data/caillau/CTParser.jl/test/test_onepass_exa.jl:27 =#
            begin
                #= /data/caillau/CTParser.jl/src/onepass.jl:111 =#
                local ex
                #= /data/caillau/CTParser.jl/src/onepass.jl:112 =#
                try
                    #= /data/caillau/CTParser.jl/src/onepass.jl:113 =#
                    ExaModels.constraint(var"##230", (x[i, grid_size] for i = 1:2); lcon = [0, 0], ucon = [0, 0])
                catch ex
                    #= /data/caillau/CTParser.jl/src/onepass.jl:115 =#
                    println("Line ", 5, ": ", "(x[1:2])(1) == [0, 0]")
                    #= /data/caillau/CTParser.jl/src/onepass.jl:116 =#
                    throw(ex)
                end
            end
            #= /data/caillau/CTParser.jl/test/test_onepass_exa.jl:28 =#
            begin
                #= /data/caillau/CTParser.jl/src/onepass.jl:111 =#
                local ex
                #= /data/caillau/CTParser.jl/src/onepass.jl:112 =#
                try
                    #= /data/caillau/CTParser.jl/src/onepass.jl:113 =#
                    begin
                        #= /data/caillau/CTParser.jl/src/onepass.jl:735 =#
                        if scheme == :trapezoidal
                            #= /data/caillau/CTParser.jl/src/onepass.jl:736 =#
                            ExaModels.constraint(var"##230", ((x[1, j + 1] - x[1, j]) - (var"##232" * (x[2, j] + x[2, j + 1])) / 2 for j = 0:grid_size - 1))
                        elseif #= /data/caillau/CTParser.jl/src/onepass.jl:737 =# scheme == :euler
                            #= /data/caillau/CTParser.jl/src/onepass.jl:738 =#
                            ExaModels.constraint(var"##230", ((x[1, j + 1] - x[1, j]) - var"##232" * x[2, j] for j = 0:grid_size - 1))
                        elseif #= /data/caillau/CTParser.jl/src/onepass.jl:739 =# scheme == :euler_b
                            #= /data/caillau/CTParser.jl/src/onepass.jl:740 =#
                            ExaModels.constraint(var"##230", ((x[1, j + 1] - x[1, j]) - var"##232" * x[2, j + 1] for j = 0:grid_size - 1))
                        else
                            #= /data/caillau/CTParser.jl/src/onepass.jl:742 =#
                            throw("unknown numerical scheme")
                        end
                    end
                catch ex
                    #= /data/caillau/CTParser.jl/src/onepass.jl:115 =#
                    println("Line ", 6, ": ", "(∂(x₁))(t) == x₂(t)")
                    #= /data/caillau/CTParser.jl/src/onepass.jl:116 =#
                    throw(ex)
                end
            end
            #= /data/caillau/CTParser.jl/test/test_onepass_exa.jl:29 =#
            begin
                #= /data/caillau/CTParser.jl/src/onepass.jl:111 =#
                local ex
                #= /data/caillau/CTParser.jl/src/onepass.jl:112 =#
                try
                    #= /data/caillau/CTParser.jl/src/onepass.jl:113 =#
                    begin
                        #= /data/caillau/CTParser.jl/src/onepass.jl:735 =#
                        if scheme == :trapezoidal
                            #= /data/caillau/CTParser.jl/src/onepass.jl:736 =#
                            ExaModels.constraint(var"##230", ((x[2, j + 1] - x[2, j]) - (var"##232" * (var"u##239"[1, j] + var"u##239"[1, j + 1])) / 2 for j = 0:grid_size - 1))
                        elseif #= /data/caillau/CTParser.jl/src/onepass.jl:737 =# scheme == :euler
                            #= /data/caillau/CTParser.jl/src/onepass.jl:738 =#
                            ExaModels.constraint(var"##230", ((x[2, j + 1] - x[2, j]) - var"##232" * var"u##239"[1, j] for j = 0:grid_size - 1))
                        elseif #= /data/caillau/CTParser.jl/src/onepass.jl:739 =# scheme == :euler_b
                            #= /data/caillau/CTParser.jl/src/onepass.jl:740 =#
                            ExaModels.constraint(var"##230", ((x[2, j + 1] - x[2, j]) - var"##232" * var"u##239"[1, j + 1] for j = 0:grid_size - 1))
                        else
                            #= /data/caillau/CTParser.jl/src/onepass.jl:742 =#
                            throw("unknown numerical scheme")
                        end
                    end
                catch ex
                    #= /data/caillau/CTParser.jl/src/onepass.jl:115 =#
                    println("Line ", 7, ": ", "(∂(x₂))(t) == u(t)")
                    #= /data/caillau/CTParser.jl/src/onepass.jl:116 =#
                    throw(ex)
                end
            end
            #= /data/caillau/CTParser.jl/test/test_onepass_exa.jl:30 =#
            begin
                #= /data/caillau/CTParser.jl/src/onepass.jl:111 =#
                local ex
                #= /data/caillau/CTParser.jl/src/onepass.jl:112 =#
                try
                    #= /data/caillau/CTParser.jl/src/onepass.jl:113 =#
                    begin
                        #= /data/caillau/CTParser.jl/src/onepass.jl:735 =#
                        if scheme == :trapezoidal
                            #= /data/caillau/CTParser.jl/src/onepass.jl:736 =#
                            ExaModels.constraint(var"##230", ((x[3, j + 1] - x[3, j]) - (var"##232" * (0.5 * var"u##239"[1, j] ^ 2 + 0.5 * var"u##239"[1, j + 1] ^ 2)) / 2 for j = 0:grid_size - 1))
                        elseif #= /data/caillau/CTParser.jl/src/onepass.jl:737 =# scheme == :euler
                            #= /data/caillau/CTParser.jl/src/onepass.jl:738 =#
                            ExaModels.constraint(var"##230", ((x[3, j + 1] - x[3, j]) - var"##232" * (0.5 * var"u##239"[1, j] ^ 2) for j = 0:grid_size - 1))
                        elseif #= /data/caillau/CTParser.jl/src/onepass.jl:739 =# scheme == :euler_b
                            #= /data/caillau/CTParser.jl/src/onepass.jl:740 =#
                            ExaModels.constraint(var"##230", ((x[3, j + 1] - x[3, j]) - var"##232" * (0.5 * var"u##239"[1, j + 1] ^ 2) for j = 0:grid_size - 1))
                        else
                            #= /data/caillau/CTParser.jl/src/onepass.jl:742 =#
                            throw("unknown numerical scheme")
                        end
                    end
                catch ex
                    #= /data/caillau/CTParser.jl/src/onepass.jl:115 =#
                    println("Line ", 8, ": ", "(∂(x₃))(t) == 0.5 * u(t) ^ 2")
                    #= /data/caillau/CTParser.jl/src/onepass.jl:116 =#
                    throw(ex)
                end
            end
            #= /data/caillau/CTParser.jl/test/test_onepass_exa.jl:31 =#
            begin
                #= /data/caillau/CTParser.jl/src/onepass.jl:111 =#
                local ex
                #= /data/caillau/CTParser.jl/src/onepass.jl:112 =#
                try
                    #= /data/caillau/CTParser.jl/src/onepass.jl:113 =#
                    ExaModels.objective(var"##230", x[3, grid_size])
                catch ex
                    #= /data/caillau/CTParser.jl/src/onepass.jl:115 =#
                    println("Line ", 9, ": ", "x₃(1) → min")
                    #= /data/caillau/CTParser.jl/src/onepass.jl:116 =#
                    throw(ex)
                end
            end
        end
        #= /data/caillau/CTParser.jl/src/onepass.jl:1009 =#
        begin
            #= /data/caillau/CTParser.jl/src/onepass.jl:994 =#
            !(isempty([1, 2, 3])) || throw(CTBase.ParsingError("dynamics not defined"))
            #= /data/caillau/CTParser.jl/src/onepass.jl:995 =#
            sort([1, 2, 3]) == 1:3 || throw(CTBase.ParsingError("some coordinates of dynamics undefined"))
        end
        #= /data/caillau/CTParser.jl/src/onepass.jl:1010 =#
        return ExaModels.ExaModel(var"##230")
    end
end
  • Solving (MadNLP + CUDSS)
This is MadNLP version v0.8.7, running with cuDSS v0.4.0

Number of nonzeros in constraint Jacobian............:    12005
Number of nonzeros in Lagrangian Hessian.............:     9000

Total number of variables............................:     4004
                     variables with only lower bounds:        0
                variables with lower and upper bounds:        0
                     variables with only upper bounds:        0
Total number of equality constraints.................:     3005
Total number of inequality constraints...............:        0
        inequality constraints with only lower bounds:        0
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        0

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0  1.0000000e-01 1.10e+00 1.00e+00  -1.0 0.00e+00    -  0.00e+00 0.00e+00   0
   1  1.0001760e-01 1.10e+00 3.84e-03  -1.0 6.88e+02  -4.0 1.00e+00 2.00e-07h  2
   2 -5.2365072e-03 1.89e-02 1.79e-07  -1.0 6.16e+00  -4.5 1.00e+00 1.00e+00h  1
   3  5.9939621e+00 2.28e-03 1.66e-04  -3.8 6.00e+00  -5.0 9.99e-01 1.00e+00h  1
   4  5.9996210e+00 2.94e-06 8.38e-07  -3.8 7.70e-02    -  1.00e+00 1.00e+00h  1

Number of Iterations....: 4

                                   (scaled)                 (unscaled)
Objective...............:   5.9996210189633494e+00    5.9996210189633494e+00
Dual infeasibility......:   8.3756005011360529e-07    8.3756005011360529e-07
Constraint violation....:   2.9426923277963834e-06    2.9426923277963834e-06
Complementarity.........:   2.0007459547789288e-06    2.0007459547789288e-06
Overall NLP error.......:   2.9426923277963834e-06    2.9426923277963834e-06

Number of objective function evaluations             = 6
Number of objective gradient evaluations             = 5
Number of constraint evaluations                     = 6
Number of constraint Jacobian evaluations            = 5
Number of Lagrangian Hessian evaluations             = 4
Total wall-clock secs in solver (w/o fun. eval./lin. alg.)  =  0.072
Total wall-clock secs in linear solver                      =  0.008
Total wall-clock secs in NLP function evaluations           =  0.003
Total wall-clock secs                                       =  0.083
This is MadNLP version v0.8.4, running with cuDSS v0.3.0

Number of nonzeros in constraint Jacobian............:   135017
Number of nonzeros in Lagrangian Hessian.............:   130008

Total number of variables............................:    35008
                     variables with only lower bounds:        0
                variables with lower and upper bounds:        0
                     variables with only upper bounds:        0
Total number of equality constraints.................:    30007
Total number of inequality constraints...............:    15004
        inequality constraints with only lower bounds:     5002
   inequality constraints with lower and upper bounds:    10002
        inequality constraints with only upper bounds:        0

[...]

Number of Iterations....: 35

                                   (scaled)                 (unscaled)
Objective...............:  -1.0142336978192805e+00   -1.0142336978192805e+00
Dual infeasibility......:   4.7384318691001681e-13    4.7384318691001681e-13
Constraint violation....:   1.4068322357215250e-09    1.4068322357215250e-09
Complementarity.........:   9.0909295306344959e-09    9.0909295306344959e-09
Overall NLP error.......:   9.0909295306344959e-09    9.0909295306344959e-09

Number of objective function evaluations             = 36
Number of objective gradient evaluations             = 36
Number of constraint evaluations                     = 36
Number of constraint Jacobian evaluations            = 36
Number of Lagrangian Hessian evaluations             = 35
Total wall-clock secs in solver (w/o fun. eval./lin. alg.)  =  0.911
Total wall-clock secs in linear solver                      =  0.227
Total wall-clock secs in NLP function evaluations           =  0.059
Total wall-clock secs                                       =  1.198

Wrap up

  • High level modelling of optimal control problems
  • Solving on CPU and GPU

Future

  • New applications (space mechanics, biology, quantum mechanics and more)
  • Additional solvers: benchmarking on CPU / GPU for optimisation, Hamiltonian shooting and pathfollowing
  • Improved AD: collab between Argonne and Inria, JLESC Shared Infra AD project...
  • ... and open to contributions! If you like the package, please give us a star ⭐️
OptimalControl.jl

control-toolbox.org

control-toolbox.org

Credits (not exhaustive!)

Acknowledgements

Jean-Baptiste Caillau is partially funded by a France 2030 support managed by the Agence Nationale de la Recherche, under the reference ANR-23-PEIA-0004 (PDE-AI project).

affiliations