Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 78 additions & 8 deletions stan/math/mix/functor/laplace_marginal_density_estimator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,30 @@ struct NewtonState {
wolfe_status.num_backtracks_ = -1; // Safe initial value for BB step
}

/**
 * @brief Constructs Newton state with a consistent (a_init, theta_init) pair.
 *
 * When the caller supplies a non-zero theta_init, the corresponding
 * a_init = Sigma^{-1} * theta_init must be provided as well, so that the
 * invariant theta = Sigma * a holds from the very first iteration. (If a
 * were left at zero while theta is non-zero, the prior term -0.5 * a'*theta
 * would vanish and the initial objective would be inconsistent.)
 *
 * The work buffers b, B, and prev_g are allocated to the latent dimension
 * here; their contents are filled in by the Newton loop.
 *
 * @tparam ObjFun Callable objective type, forwarded into wolfe_info
 * @tparam ThetaGradFun Callable gradient type, forwarded into wolfe_info
 * @tparam ThetaInitializer Vector-like initializer type for theta
 * @param theta_size Dimension of the latent space
 * @param obj_fun Objective function: (a, theta) -> double
 * @param theta_grad_f Gradient function: theta -> grad
 * @param a_init Initial a value consistent with theta_init
 * @param theta_init Initial theta value
 */
template <typename ObjFun, typename ThetaGradFun, typename ThetaInitializer>
NewtonState(int theta_size, ObjFun&& obj_fun, ThetaGradFun&& theta_grad_f,
const Eigen::VectorXd& a_init, ThetaInitializer&& theta_init)
// The trailing 0 is a tag argument selecting the WolfeInfo overload that
// accepts an explicit a_init (see WolfeInfo's `int /*tag*/` constructor).
: wolfe_info(std::forward<ObjFun>(obj_fun), a_init,
std::forward<ThetaInitializer>(theta_init),
std::forward<ThetaGradFun>(theta_grad_f), 0),
b(theta_size),
B(theta_size, theta_size),
prev_g(theta_size) {
wolfe_status.num_backtracks_ = -1; // Safe initial value for BB step
}

/**
* @brief Access the current step state (mutable).
* @return Reference to current WolfeStep
Expand Down Expand Up @@ -426,9 +450,13 @@ inline void llt_with_jitter(LLT& llt_B, B_t& B, double min_jitter = 1e-10,
double max_jitter = 1e-5) {
llt_B.compute(B);
if (llt_B.info() != Eigen::Success) {
double prev_jitter = 0.0;
double jitter_try = min_jitter;
for (; jitter_try < max_jitter; jitter_try *= 10) {
B.diagonal().array() += jitter_try;
// Remove previously added jitter before adding the new (larger) amount,
// so that the total diagonal perturbation is exactly jitter_try.
B.diagonal().array() += (jitter_try - prev_jitter);
prev_jitter = jitter_try;
llt_B.compute(B);
if (llt_B.info() == Eigen::Success) {
break;
Expand Down Expand Up @@ -935,6 +963,9 @@ inline auto run_newton_loop(SolverPolicy& solver, NewtonStateT& state,
scratch.alpha() = 1.0;
update_fun(scratch, state.curr(), state.prev(), scratch.eval_,
state.wolfe_info.p_);
// Save the full Newton step objective before the Wolfe line search
// overwrites scratch with intermediate trial points.
const double full_newton_obj = scratch.eval_.obj();
if (scratch.alpha() <= options.line_search.min_alpha) {
state.wolfe_status.accept_ = false;
finish_update = true;
Expand All @@ -953,15 +984,40 @@ inline auto run_newton_loop(SolverPolicy& solver, NewtonStateT& state,
state.wolfe_status = internal::wolfe_line_search(
state.wolfe_info, update_fun, options.line_search, msgs);
}
// When the Wolfe line search rejects, don't immediately terminate.
// Instead, let the Newton loop try at least one more iteration.
// The original code compared the stale curr.obj() (which equalled
// prev.obj() after the swap in update_next_step) and would always
// terminate on ANY Wolfe rejection — even on the very first Newton
// step. Now we only declare search_failed if the full Newton step
// itself didn't improve the objective.
bool search_failed;
if (!state.wolfe_status.accept_) {
if (full_newton_obj > state.prev().obj()) {
// The full Newton step (evaluated before Wolfe ran) improved
// the objective. Re-evaluate scratch at the full Newton step
// so we can accept it as the current iterate.
scratch.eval_.alpha() = 1.0;
update_fun(scratch, state.curr(), state.prev(), scratch.eval_,
state.wolfe_info.p_);
state.curr().update(scratch);
state.wolfe_status.accept_ = true;
search_failed = false;
} else {
search_failed = true;
}
} else {
search_failed = false;
}
Comment on lines +987 to +1011
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what this is trying to do. It correctly identified the main issue — that curr can be stale — but, having called that out, it did not actually do anything about it. I was being too memory-greedy / too clever here, reusing the memory in curr in a few places where I should not have.

That said, I don't disagree with what it is doing here. If we are stuck in a loop with a smallish step size on a problem with awkward geometry, it can be worth attempting a full-length step to try to jump out of it. The code actually used to contain something like this, but at the time I dismissed it. Note that we already evaluate alpha = 1 a few lines above, so we should reuse those results here instead of rerunning update_fun. (update_fun checks whether the objective or gradient is NaN and reduces the step size from 1 until both are finite.)

I'm going to modify this, but I will keep the backup Newton-step check.

/**
* Stop when objective change is small, or when a rejected Wolfe step
* fails to improve; finish_update then exits the Newton loop.
* Stop when objective change is small (absolute AND relative), or when
* a rejected Wolfe step fails to improve; finish_update then exits the
* Newton loop.
*/
double obj_change = std::abs(state.curr().obj() - state.prev().obj());
bool objective_converged
= std::abs(state.curr().obj() - state.prev().obj())
< options.tolerance;
bool search_failed = (!state.wolfe_status.accept_
&& state.curr().obj() <= state.prev().obj());
= obj_change < options.tolerance
&& obj_change < options.tolerance * std::abs(state.prev().obj());
finish_update = objective_converged || search_failed;
}
if (finish_update) {
Expand Down Expand Up @@ -1152,7 +1208,21 @@ inline auto laplace_marginal_density_est(
return laplace_likelihood::theta_grad(ll_fun, theta_val, ll_args, msgs);
};
decltype(auto) theta_init = theta_init_impl<InitTheta>(theta_size, options);
internal::NewtonState state(theta_size, obj_fun, theta_grad_f, theta_init);
// When the user supplies a non-zero theta_init, we must initialise a
// consistently so that the invariant theta = Sigma * a holds. Otherwise
// the prior term -0.5 * a'*theta vanishes (a=0 while theta!=0), inflating
// the initial objective and causing the Wolfe line search to reject the
// first Newton step.
auto make_state = [&](auto&& theta_0) {
if constexpr (InitTheta) {
Eigen::VectorXd a_init = covariance.llt().solve(Eigen::VectorXd(theta_0));
return internal::NewtonState(theta_size, obj_fun, theta_grad_f, a_init,
theta_0);
} else {
return internal::NewtonState(theta_size, obj_fun, theta_grad_f, theta_0);
}
};
auto state = make_state(theta_init);
// Start with safe step size
auto update_fun = create_update_fun(
std::move(obj_fun), std::move(theta_grad_f), covariance, options);
Expand Down
28 changes: 26 additions & 2 deletions stan/math/mix/functor/wolfe_line_search.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,29 @@ struct WolfeInfo {
"theta and likelihood arguments.");
}
}
/**
 * Construct WolfeInfo with a consistent (a_init, theta_init) pair.
 *
 * When the caller supplies a non-zero theta_init, the corresponding
 * a_init = Sigma^{-1} * theta_init must be provided so that the
 * invariant theta = Sigma * a holds at initialization. This avoids
 * an inflated initial objective (the prior term -0.5 * a'*theta would
 * otherwise vanish when a is zero but theta is not).
 *
 * prev_ is initialized as a copy of curr_, and scratch_ is sized to the
 * latent dimension (a_init.size()).
 *
 * @param obj_fun Objective callable, forwarded into the current step data
 * @param a_init Initial a value consistent with theta0
 * @param theta0 Initial theta value, forwarded into the current step data
 * @param theta_grad_f Gradient callable, forwarded into the current step data
 * @param tag Unused disambiguator: distinguishes this overload from the
 *        constructor that does not take an explicit a_init
 * @throws std::domain_error if the objective evaluated at (a_init, theta0)
 *         is not finite
 */
template <typename ObjFun, typename Theta0, typename ThetaGradF>
WolfeInfo(ObjFun&& obj_fun, const Eigen::VectorXd& a_init, Theta0&& theta0,
ThetaGradF&& theta_grad_f, int /*tag*/)
: curr_(std::forward<ObjFun>(obj_fun), a_init,
std::forward<Theta0>(theta0),
std::forward<ThetaGradF>(theta_grad_f)),
prev_(curr_),
scratch_(a_init.size()) {
// Fail fast: a non-finite objective at the starting point means the
// likelihood arguments are invalid, not that optimization should proceed.
if (!std::isfinite(curr_.obj())) {
throw std::domain_error(
"laplace_marginal_density: log likelihood is not finite at initial "
"theta and likelihood arguments.");
}
}
WolfeInfo(WolfeData&& curr, WolfeData&& prev)
: curr_(std::move(curr)),
prev_(std::move(prev)),
Expand Down Expand Up @@ -902,9 +925,10 @@ inline WolfeStatus wolfe_line_search(Info& wolfe_info, UpdateFun&& update_fun,
} else { // [3]
high = mid;
}
} else {
// [4]
high = mid;
}
// [4]
high = mid;
} else {
// [5]
high = mid;
Expand Down