Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# bayesplot (development version)

* Added unit tests for `mcmc_areas_ridges_data()`, `mcmc_parcoord_data()`, and `mcmc_trace_data()`.
* Added unit tests for `ppc_error_data()` and `ppc_loo_pit_data()` covering output structure, argument handling, and edge cases.
* Added vignette sections demonstrating `*_data()` companion functions for building custom ggplot2 visualizations (#435)
* Extract `drop_singleton_values()` helper in `mcmc_nuts_treedepth()` to remove duplicated filtering logic.
Expand Down
94 changes: 55 additions & 39 deletions tests/testthat/test-mcmc-intervals.R
Original file line number Diff line number Diff line change
@@ -1,42 +1,5 @@
source(test_path("data-for-mcmc-tests.R"))

test_that("mcmc_intervals_data computes quantiles", {
xs <- melt_mcmc(merge_chains(prepare_mcmc_array(arr, pars = "beta[1]")))
d <- mcmc_intervals_data(arr, pars = "beta[1]",
prob = .3, prob_outer = .5)

qs <- unlist(d[, c("ll", "l", "m", "h", "hh")])
by_hand <- quantile(xs$Value, c(.25, .35, .5, .65, .75))
expect_equal(qs, by_hand, ignore_attr = TRUE)

expect_equal(d$parameter, factor("beta[1]"))
expect_equal(d$outer_width, .5)
expect_equal(d$inner_width, .3)
expect_equal(d$point_est, "median")

d2 <- mcmc_areas_data(arr, pars = "beta[1]", prob = .3, prob_outer = .5)
sets <- split(d2, d2$interval)

expect_equal(range(sets$inner$x), c(d$l, d$h))
expect_equal(range(sets$outer$x), c(d$ll, d$hh))
})

test_that("mcmc_intervals_data computes point estimates", {
xs <- melt_mcmc(merge_chains(prepare_mcmc_array(arr, pars = "beta[2]")))
d <- mcmc_intervals_data(arr, pars = "beta[2]",
prob = .3, prob_outer = .5, point_est = "mean")

expect_equal(d$m, mean(xs$Value), ignore_attr = TRUE)
expect_equal(d$parameter, factor("beta[2]"))
expect_equal(d$point_est, "mean")

d <- mcmc_intervals_data(arr, pars = "(Intercept)",
prob = .3, prob_outer = .5,
point_est = "none")
expect_true(!("m" %in% names(d)))
expect_equal(d$point_est, "none")
})

test_that("mcmc_intervals returns a ggplot object", {
expect_gg(mcmc_intervals(arr, pars = "beta[1]", regex_pars = "x\\:"))
expect_gg(mcmc_intervals(arr1chain, pars = "beta[1]", regex_pars = "Intercept"))
Expand Down Expand Up @@ -115,6 +78,45 @@ test_that("mcmc_intervals/areas with rhat", {
}
})

# _data() tests ----------------------------------------------------------------

test_that("mcmc_intervals_data computes quantiles", {
xs <- melt_mcmc(merge_chains(prepare_mcmc_array(arr, pars = "beta[1]")))
d <- mcmc_intervals_data(arr, pars = "beta[1]",
prob = .3, prob_outer = .5)

qs <- unlist(d[, c("ll", "l", "m", "h", "hh")])
by_hand <- quantile(xs$Value, c(.25, .35, .5, .65, .75))
expect_equal(qs, by_hand, ignore_attr = TRUE)

expect_equal(d$parameter, factor("beta[1]"))
expect_equal(d$outer_width, .5)
expect_equal(d$inner_width, .3)
expect_equal(d$point_est, "median")

d2 <- mcmc_areas_data(arr, pars = "beta[1]", prob = .3, prob_outer = .5)
sets <- split(d2, d2$interval)

expect_equal(range(sets$inner$x), c(d$l, d$h))
expect_equal(range(sets$outer$x), c(d$ll, d$hh))
})

test_that("mcmc_intervals_data computes point estimates", {
xs <- melt_mcmc(merge_chains(prepare_mcmc_array(arr, pars = "beta[2]")))
d <- mcmc_intervals_data(arr, pars = "beta[2]",
prob = .3, prob_outer = .5, point_est = "mean")

expect_equal(d$m, mean(xs$Value), ignore_attr = TRUE)
expect_equal(d$parameter, factor("beta[2]"))
expect_equal(d$point_est, "mean")

d <- mcmc_intervals_data(arr, pars = "(Intercept)",
prob = .3, prob_outer = .5,
point_est = "none")
expect_true(!("m" %in% names(d)))
expect_equal(d$point_est, "none")
})

test_that("mcmc_areas_data computes density", {
areas_data <- mcmc_areas_data(arr, point_est = "none")
areas_data <- areas_data[areas_data$interval_width == 1, ]
Expand Down Expand Up @@ -153,7 +155,7 @@ test_that("compute_column_density can use density options (#118)", {
expect_error(mcmc_areas_data(arr, kernel = stop()))
})

test_that("probabilities outside of [0,1] cause an error", {
test_that("mcmc_intervals_data errors for probabilities outside of [0,1]", {
expect_error(mcmc_intervals_data(arr, prob = -0.1),
"must be in \\[0,1\\]")
expect_error(mcmc_intervals_data(arr, prob = 1.1),
Expand All @@ -164,14 +166,28 @@ test_that("probabilities outside of [0,1] cause an error", {
"must be in \\[0,1\\]")
})

test_that("inconsistent probabilities raise warning (#138)", {
test_that("mcmc_intervals_data warns for inconsistent probabilities (#138)", {
expect_warning(
mcmc_intervals_data(arr, prob = .9, prob_outer = .8),
"`prob_outer` .* is less than `prob`"
)
})


test_that("mcmc_areas_ridges_data returns correct structure", {
d <- mcmc_areas_ridges_data(arr, pars = c("beta[1]", "sigma"), prob = 0.5, prob_outer = 0.9)
expect_s3_class(d, "data.frame")
expect_named(
d,
c(
"parameter", "interval", "interval_width", "x", "density",
"scaled_density", "plotting_density"
)
)
expect_setequal(unique(d$interval), c("inner", "outer"))
expect_false("point" %in% d$interval)
expect_equal(unique(as.character(d$parameter)), c("beta[1]", "sigma"))
})


# Visual tests -----------------------------------------------------------------
Expand Down
38 changes: 36 additions & 2 deletions tests/testthat/test-mcmc-scatter-and-parcoord.R
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,6 @@ test_that("pairs_condition message if multiple args specified", {
})



# mcmc_parcoord -----------------------------------------------------------
test_that("mcmc_parcoord returns a ggplot object", {
expect_gg(mcmc_parcoord(arr, pars = c("(Intercept)", "sigma")))
Expand Down Expand Up @@ -351,7 +350,6 @@ test_that("mcmc_parcoord throws correct warnings and errors", {
)
})


# parcoord_style_np -------------------------------------------------------
test_that("parcoord_style_np returns correct structure", {
style <- parcoord_style_np()
Expand All @@ -375,6 +373,42 @@ test_that("parcoord_style_np throws correct errors", {
)
})

# mcmc_parcoord_data -------------------------------------------------

test_that("mcmc_parcoord_data returns expected structure", {
d <- mcmc_parcoord_data(arr, pars = c("(Intercept)", "sigma"))
expect_s3_class(d, "data.frame")
expect_named(d, c("Draw", "Parameter", "Value", "Divergent"))

draws_by_parameter <- split(d$Draw, d$Parameter)
expected_draws <- seq_len(dim(arr)[1] * dim(arr)[2])
expect_equal(draws_by_parameter[[1]], expected_draws)
expect_equal(draws_by_parameter[[2]], expected_draws)
})

test_that("mcmc_parcoord_data sets Divergent to 0 when np is NULL", {
d <- mcmc_parcoord_data(arr, pars = c("(Intercept)", "sigma"))
expect_true(all(d$Divergent == 0))
})

test_that("mcmc_parcoord_data joins divergence information from np", {
fake_np <- data.frame(
Iteration = rep(seq_len(dim(arr)[1]), each = dim(arr)[2]),
Chain = rep(seq_len(dim(arr)[2]), times = dim(arr)[1]),
Parameter = factor("divergent__"),
Value = as.integer(rep(c(0, 1, 0, 1), times = dim(arr)[1]))
)
d <- mcmc_parcoord_data(arr, pars = c("(Intercept)", "sigma"), np = fake_np)

expect_false(anyNA(d$Divergent))
expect_equal(sum(d$Divergent == 1), 400)
expect_equal(sum(d$Divergent == 0), 400)
})

test_that("mcmc_parcoord_data errors with fewer than 2 parameters", {
expect_error(mcmc_parcoord_data(arr, pars = "sigma"), "at least two")
})


# Visual tests -----------------------------------------------------------------

Expand Down
37 changes: 37 additions & 0 deletions tests/testthat/test-mcmc-traces.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,44 @@ test_that("mcmc_trace 'np' argument works", {
"No divergences to plot.")
})

# mcmc_trace_data ----------------------------------------------------

test_that("mcmc_trace_data returns plotting data with expected columns", {
d <- mcmc_trace_data(arr, pars = "beta[1]")
expect_s3_class(d, "tbl_df")
expect_named(
d,
c(
"parameter", "value", "value_rank", "iteration", "chain",
"n_chains", "n_iterations", "n_parameters", "highlight", "warmup"
)
)
expect_equal(nrow(d), dim(arr)[1] * dim(arr)[2])
})

test_that("mcmc_trace_data highlight argument works", {
d <- mcmc_trace_data(arr, pars = "beta[1]", highlight = 2)
expect_true(all(d$highlight[d$chain == 2]))
expect_true(all(!d$highlight[d$chain != 2]))
})

test_that("mcmc_trace_data warmup labeling works", {
d <- mcmc_trace_data(arr, pars = "beta[1]", n_warmup = 20)
expect_true(all(d$warmup[d$iteration <= 20]))
expect_true(all(!d$warmup[d$iteration > 20]))
})

test_that("mcmc_trace_data iter1 shifts iterations", {
d <- mcmc_trace_data(arr, pars = "beta[1]", iter1 = 100)
expect_true(min(d$iteration) == 101)
})

test_that("mcmc_trace_data computes value_rank within each parameter", {
d <- mcmc_trace_data(arr, pars = c("beta[1]", "beta[2]"))
observed_ranks <- split(d$value_rank, d$parameter)
expected_ranks <- lapply(split(d$value, d$parameter), rank, ties.method = "average")
expect_equal(observed_ranks, expected_ranks)
})


# Visual tests -----------------------------------------------------------------
Expand Down
Loading