|
2 | 2 | using DiffEqBayes, BenchmarkTools |
3 | 3 |
|
4 | 4 |
|
5 | | -using OrdinaryDiffEq, RecursiveArrayTools, Distributions, ParameterizedFunctions, StanSample, DynamicHMC |
| 5 | +using OrdinaryDiffEq, RecursiveArrayTools, Distributions, ParameterizedFunctions, |
| 6 | + StanSample, DynamicHMC |
6 | 7 | using Plots, StaticArrays, Turing, LinearAlgebra |
7 | 8 |
|
8 | 9 |
|
9 | | -gr(fmt=:png) |
| 10 | +gr(fmt = :png) |
10 | 11 |
|
11 | 12 |
|
12 | 13 | fitz = @ode_def FitzhughNagumo begin |
13 | | - dv = v - 0.33*v^3 -w + l |
14 | | - dw = τinv*(v + a - b*w) |
| 14 | + dv = v - 0.33*v^3 - w + l |
| 15 | + dw = τinv*(v + a - b*w) |
15 | 16 | end a b τinv l |
16 | 17 |
|
17 | 18 |
|
18 | | -prob_ode_fitzhughnagumo = ODEProblem(fitz, [1.0,1.0], (0.0,10.0), [0.7,0.8,1/12.5,0.5]) |
| 19 | +prob_ode_fitzhughnagumo = ODEProblem(fitz, [1.0, 1.0], (0.0, 10.0), [0.7, 0.8, 1/12.5, 0.5]) |
19 | 20 | sol = solve(prob_ode_fitzhughnagumo, Tsit5()) |
20 | 21 |
|
21 | 22 |
|
22 | | -sprob_ode_fitzhughnagumo = ODEProblem{false,SciMLBase.FullSpecialize}(fitz, SA[1.0,1.0], (0.0,10.0), SA[0.7,0.8,1/12.5,0.5]) |
| 23 | +sprob_ode_fitzhughnagumo = ODEProblem{false, SciMLBase.FullSpecialize}( |
| 24 | + fitz, SA[1.0, 1.0], (0.0, 10.0), SA[0.7, 0.8, 1 / 12.5, 0.5]) |
23 | 25 | sol = solve(sprob_ode_fitzhughnagumo, Tsit5()) |
24 | 26 |
|
25 | 27 |
|
26 | | -t = collect(range(1,stop=10,length=10)) |
| 28 | +t = collect(range(1, stop = 10, length = 10)) |
27 | 29 | sig = 0.20 |
28 | 30 | data = convert(Array, VectorOfArray([(sol(t[i]) + sig*randn(2)) for i in 1:length(t)])) |
29 | 31 |
|
30 | 32 |
|
31 | | -scatter(t, data[1,:]) |
32 | | -scatter!(t, data[2,:]) |
| 33 | +scatter(t, data[1, :]) |
| 34 | +scatter!(t, data[2, :]) |
33 | 35 | plot!(sol) |
34 | 36 |
|
35 | 37 |
|
36 | | -priors = [truncated(Normal(1.0,0.5),0,1.5), truncated(Normal(1.0,0.5),0,1.5), truncated(Normal(0.0,0.5),0.0,0.5), truncated(Normal(0.5,0.5),0,1)] |
| 38 | +priors = [truncated(Normal(1.0, 0.5), 0, 1.5), truncated(Normal(1.0, 0.5), 0, 1.5), |
| 39 | + truncated(Normal(0.0, 0.5), 0.0, 0.5), truncated(Normal(0.5, 0.5), 0, 1)] |
37 | 40 |
|
38 | 41 |
|
39 | | -@time bayesian_result_stan = stan_inference(prob_ode_fitzhughnagumo,t,data,priors; delta = 0.65, num_samples = 10_000, print_summary=false, vars=(DiffEqBayes.StanODEData(), InverseGamma(2, 3))) |
| 42 | +@time bayesian_result_stan = stan_inference( |
| 43 | + prob_ode_fitzhughnagumo, t, data, priors; delta = 0.65, num_samples = 10_000, |
| 44 | + print_summary = false, vars = (DiffEqBayes.StanODEData(), InverseGamma(2, 3))) |
40 | 45 |
|
41 | 46 |
|
42 | 47 | @model function fitlv(data, prob) |
43 | 48 |
|
44 | 49 | # Prior distributions. |
45 | 50 | σ ~ InverseGamma(2, 3) |
46 | | - a ~ truncated(Normal(1.0,0.5),0,1.5) |
47 | | - b ~ truncated(Normal(1.0,0.5),0,1.5) |
48 | | - τinv ~ truncated(Normal(0.0,0.5),0.0,0.5) |
49 | | - l ~ truncated(Normal(0.5,0.5),0,1) |
| 51 | + a ~ truncated(Normal(1.0, 0.5), 0, 1.5) |
| 52 | + b ~ truncated(Normal(1.0, 0.5), 0, 1.5) |
| 53 | + τinv ~ truncated(Normal(0.0, 0.5), 0.0, 0.5) |
| 54 | + l ~ truncated(Normal(0.5, 0.5), 0, 1) |
50 | 55 |
|
51 | 56 | # Simulate Lotka-Volterra model. |
52 | | - p = SA[a,b,τinv,l] |
| 57 | + p = SA[a, b, τinv, l] |
53 | 58 | _prob = remake(prob, p = p) |
54 | | - predicted = solve(_prob, Tsit5(); saveat=t) |
| 59 | + predicted = solve(_prob, Tsit5(); saveat = t) |
55 | 60 |
|
56 | 61 | # Observations. |
57 | 62 | for i in 1:length(predicted) |
|
63 | 68 |
|
64 | 69 | model = fitlv(data, sprob_ode_fitzhughnagumo) |
65 | 70 |
|
66 | | -@time chain = sample(model, Turing.NUTS(0.65), 10000; progress=false) |
| 71 | +@time chain = sample(model, Turing.NUTS(0.65), 10000; progress = false) |
67 | 72 |
|
68 | 73 |
|
69 | | -@time bayesian_result_turing = turing_inference(prob_ode_fitzhughnagumo,Tsit5(),t,data,priors;num_samples = 10_000) |
| 74 | +@time bayesian_result_turing = turing_inference( |
| 75 | + prob_ode_fitzhughnagumo, Tsit5(), t, data, priors; num_samples = 10_000) |
70 | 76 |
|
71 | 77 |
|
72 | 78 | using SciMLBenchmarks |
73 | | -SciMLBenchmarks.bench_footer(WEAVE_ARGS[:folder],WEAVE_ARGS[:file]) |
| 79 | +SciMLBenchmarks.bench_footer(WEAVE_ARGS[:folder], WEAVE_ARGS[:file]) |
74 | 80 |
|
0 commit comments