Skip to content

Commit 3f05bea

Browse files
committed
Add MWE
1 parent 5548437 commit 3f05bea

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

mwe/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[deps]
2+
BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"

mwe/mwe.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
using BridgeStan
2+
3+
stan_code = """
4+
functions {
5+
real partial_sum(array[] real slice, int start, int end,
6+
tuple(real, int) params) {
7+
real mu = params.1;
8+
int K = params.2;
9+
real lp = 0;
10+
for (i in 1:size(slice)) {
11+
lp += normal_lpdf(slice[i] | mu, K);
12+
}
13+
return lp;
14+
}
15+
}
16+
data {
17+
int N;
18+
int K;
19+
array[N] real y;
20+
}
21+
parameters {
22+
real mu;
23+
}
24+
model {
25+
mu ~ normal(0, 10);
26+
target += reduce_sum(partial_sum, y, 1, (mu, K));
27+
}
28+
"""
29+
30+
stan_math = ENV["MWE_RUN_DIR"]
31+
label = ENV["MWE_LABEL"]
32+
33+
workdir = mktempdir()
34+
stan_file = joinpath(workdir, "mwe.stan")
35+
write(stan_file, stan_code)
36+
37+
lib = compile_model(stan_file; make_args=["MATH=$stan_math/", "STAN_THREADS=true"])
38+
39+
data = """{"N": 5, "K": 2, "y": [1.0, 2.0, 3.0, 4.0, 5.0]}"""
40+
sm = StanModel(lib, data)
41+
42+
params = [3.0] # mu (unconstrained)
43+
lp = log_density(sm, params)
44+
lp_grad, grad = log_density_gradient(sm, params)
45+
46+
println("[$label] log_density = $lp")
47+
println("[$label] gradient = $grad")
48+
49+
@assert isfinite(lp) "log_density should be finite"
50+
@assert all(isfinite, grad) "gradient should be finite"
51+
@assert lp lp_grad "log_density values should match"
52+
@assert length(grad) == 1 "gradient should have 1 element"
53+
54+
println("[$label] All checks passed!")

0 commit comments

Comments
 (0)