Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/stan/services/experimental/advi/fullrank.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
#include <stan/services/error_codes.hpp>
#include <stan/io/var_context.hpp>
#include <stan/variational/advi.hpp>
#include <stan/services/util/duration_diff.hpp>
#include <stan/services/util/write_timing.hpp>
#include <chrono>
#include <string>
#include <vector>

Expand Down Expand Up @@ -57,6 +60,7 @@ int fullrank(Model& model, const stan::io::var_context& init,
callbacks::writer& init_writer,
callbacks::writer& parameter_writer,
callbacks::writer& diagnostic_writer) {
auto start_time = std::chrono::steady_clock::now();
util::experimental_message(logger);

stan::rng_t rng = util::create_rng(random_seed, chain);
Expand Down Expand Up @@ -94,6 +98,10 @@ int fullrank(Model& model, const stan::io::var_context& init,
return error_codes::SOFTWARE;
}

auto end_time = std::chrono::steady_clock::now();
util::write_timing(util::duration_diff(start_time, end_time), "ADVI",
parameter_writer, logger);

return stan::services::error_codes::OK;
}
} // namespace advi
Expand Down
8 changes: 8 additions & 0 deletions src/stan/services/experimental/advi/meanfield.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
#include <stan/services/error_codes.hpp>
#include <stan/io/var_context.hpp>
#include <stan/variational/advi.hpp>
#include <stan/services/util/duration_diff.hpp>
#include <stan/services/util/write_timing.hpp>
#include <chrono>
#include <string>
#include <vector>

Expand Down Expand Up @@ -57,6 +60,7 @@ int meanfield(Model& model, const stan::io::var_context& init,
callbacks::writer& init_writer,
callbacks::writer& parameter_writer,
callbacks::writer& diagnostic_writer) {
auto start_time = std::chrono::steady_clock::now();
util::experimental_message(logger);

stan::rng_t rng = util::create_rng(random_seed, chain);
Expand Down Expand Up @@ -93,6 +97,10 @@ int meanfield(Model& model, const stan::io::var_context& init,
return error_codes::SOFTWARE;
}

auto end_time = std::chrono::steady_clock::now();
util::write_timing(util::duration_diff(start_time, end_time), "ADVI",
parameter_writer, logger);

return stan::services::error_codes::OK;
}
} // namespace advi
Expand Down
8 changes: 8 additions & 0 deletions src/stan/services/optimize/bfgs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
#include <stan/optimization/bfgs.hpp>
#include <stan/services/util/initialize.hpp>
#include <stan/services/util/create_rng.hpp>
#include <stan/services/util/duration_diff.hpp>
#include <stan/services/util/write_timing.hpp>
#include <chrono>
#include <fstream>
#include <iostream>
#include <iomanip>
Expand Down Expand Up @@ -57,6 +60,7 @@ int bfgs(Model& model, const stan::io::var_context& init,
bool save_iterations, int refresh, callbacks::interrupt& interrupt,
callbacks::logger& logger, callbacks::writer& init_writer,
callbacks::writer& parameter_writer) {
auto start_time = std::chrono::steady_clock::now();
stan::rng_t rng = util::create_rng(random_seed, chain);

std::vector<int> disc_vector;
Expand Down Expand Up @@ -211,6 +215,10 @@ int bfgs(Model& model, const stan::io::var_context& init,
return_code = error_codes::SOFTWARE;
}

auto end_time = std::chrono::steady_clock::now();
util::write_timing(util::duration_diff(start_time, end_time), "Optimization",
parameter_writer, logger);

return return_code;
}

Expand Down
9 changes: 9 additions & 0 deletions src/stan/services/optimize/laplace_sample.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#include <stan/math/rev.hpp>
#include <stan/services/error_codes.hpp>
#include <stan/services/util/create_rng.hpp>
#include <stan/services/util/duration_diff.hpp>
#include <stan/services/util/write_timing.hpp>
#include <chrono>
#include <string>
#include <type_traits>
#include <vector>
Expand Down Expand Up @@ -176,6 +179,7 @@ int laplace_sample(const Model& model, const Eigen::VectorXd& theta_hat,
int refresh, callbacks::interrupt& interrupt,
callbacks::logger& logger, callbacks::writer& sample_writer,
callbacks::structured_writer& hessian_writer) {
auto start_time = std::chrono::steady_clock::now();
try {
internal::laplace_sample<jacobian>(model, theta_hat, draws, calculate_lp,
random_seed, refresh, interrupt, logger,
Expand All @@ -184,6 +188,11 @@ int laplace_sample(const Model& model, const Eigen::VectorXd& theta_hat,
logger.error(e.what());
return error_codes::CONFIG;
}

auto end_time = std::chrono::steady_clock::now();
util::write_timing(util::duration_diff(start_time, end_time),
"Laplace Approximation", sample_writer, logger);

return error_codes::OK;
}

Expand Down
8 changes: 8 additions & 0 deletions src/stan/services/optimize/lbfgs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
#include <stan/services/error_codes.hpp>
#include <stan/services/util/initialize.hpp>
#include <stan/services/util/create_rng.hpp>
#include <stan/services/util/duration_diff.hpp>
#include <stan/services/util/write_timing.hpp>
#include <chrono>
#include <fstream>
#include <iostream>
#include <iomanip>
Expand Down Expand Up @@ -59,6 +62,7 @@ int lbfgs(Model& model, const stan::io::var_context& init,
int refresh, callbacks::interrupt& interrupt,
callbacks::logger& logger, callbacks::writer& init_writer,
callbacks::writer& parameter_writer) {
auto start_time = std::chrono::steady_clock::now();
stan::rng_t rng = util::create_rng(random_seed, chain);

std::vector<int> disc_vector;
Expand Down Expand Up @@ -206,6 +210,10 @@ int lbfgs(Model& model, const stan::io::var_context& init,
return_code = error_codes::SOFTWARE;
}

auto end_time = std::chrono::steady_clock::now();
util::write_timing(util::duration_diff(start_time, end_time), "Optimization",
parameter_writer, logger);

return return_code;
}

Expand Down
9 changes: 9 additions & 0 deletions src/stan/services/optimize/newton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
#include <stan/services/error_codes.hpp>
#include <stan/services/util/initialize.hpp>
#include <stan/services/util/create_rng.hpp>
#include <stan/services/util/duration_diff.hpp>
#include <stan/services/util/write_timing.hpp>
#include <chrono>
#include <cmath>
#include <limits>
#include <string>
Expand Down Expand Up @@ -45,6 +48,7 @@ int newton(Model& model, const stan::io::var_context& init,
callbacks::interrupt& interrupt, callbacks::logger& logger,
callbacks::writer& init_writer,
callbacks::writer& parameter_writer) {
auto start_time = std::chrono::steady_clock::now();
stan::rng_t rng = util::create_rng(random_seed, chain);

std::vector<int> disc_vector;
Expand Down Expand Up @@ -135,6 +139,11 @@ int newton(Model& model, const stan::io::var_context& init,
values.insert(values.begin(), {lp, static_cast<double>(ret)});
parameter_writer(values);
}

auto end_time = std::chrono::steady_clock::now();
util::write_timing(util::duration_diff(start_time, end_time), "Optimization",
parameter_writer, logger);

return error_codes::OK;
}

Expand Down
51 changes: 51 additions & 0 deletions src/stan/services/util/write_timing.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#ifndef STAN_SERVICES_UTIL_WRITE_TIMING_HPP
#define STAN_SERVICES_UTIL_WRITE_TIMING_HPP

#include <stan/callbacks/logger.hpp>
#include <stan/callbacks/writer.hpp>
#include <sstream>
#include <string>

namespace stan {
namespace services {
namespace util {

/**
* Internal method to write timing information to a writer or logger.
*
* @param[in] delta_t time in seconds
* @param[in] label label for the timing info
* @param[in] writer output stream or logger
*/
template <typename F>
void write_timing(double delta_t, const std::string& label, F& writer) {
std::string title(" Elapsed Time: ");
writer("");

std::stringstream ss;
ss << title << delta_t << " seconds (" << label << ")";
writer(ss.str());

writer("");
}

/**
* Write timing information to both writer and logger.
*
* @param[in] delta_t time in seconds
* @param[in] label label for the timing info
* @param[in,out] writer output stream
* @param[in,out] logger messages are written through the logger
*/
inline void write_timing(double delta_t, const std::string& label,
callbacks::writer& writer, callbacks::logger& logger) {
write_timing(delta_t, label, writer);
auto logger_info = [&logger](const std::string& msg) { logger.info(msg); };
write_timing(delta_t, label, logger_info);
}

} // namespace util
} // namespace services
} // namespace stan

#endif
34 changes: 34 additions & 0 deletions src/test/unit/services/experimental/advi/fullrank_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,37 @@ TEST_F(ServicesExperimentalAdvi, fullrank) {

EXPECT_EQ(0, interrupt.call_count());
}

TEST_F(ServicesExperimentalAdvi, fullrank_timing_info) {
unsigned int seed = 0;
unsigned int chain = 1;
double init_radius = 0;
int grad_samples = 1;
int elbo_samples = 100;
int max_iterations = 100;
double tol_rel_obj = 0.01;
double eta = 1.0;
bool adapt_engaged = true;
int adapt_iterations = 50;
int eval_elbo = 100;
int output_samples = 10;

stan::services::experimental::advi::fullrank(
model, context, seed, chain, init_radius, grad_samples, elbo_samples,
max_iterations, tol_rel_obj, eta, adapt_engaged, adapt_iterations,
eval_elbo, output_samples, interrupt, logger, init, parameter,
diagnostic);

EXPECT_TRUE(logger.find_info("Elapsed Time:") > 0)
<< "Should find 'Elapsed Time:' in logger info output";

bool found_in_parameter = false;
for (const auto& msg : parameter.string_values()) {
if (msg.find("Elapsed Time:") != std::string::npos) {
found_in_parameter = true;
break;
}
}
EXPECT_TRUE(found_in_parameter)
<< "Should find 'Elapsed Time:' in parameter_writer output";
}
34 changes: 34 additions & 0 deletions src/test/unit/services/experimental/advi/meanfield_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,37 @@ TEST_F(ServicesExperimentalAdvi, meanfield) {

EXPECT_EQ(0, interrupt.call_count());
}

TEST_F(ServicesExperimentalAdvi, meanfield_timing_info) {
unsigned int seed = 0;
unsigned int chain = 1;
double init_radius = 0;
int grad_samples = 1;
int elbo_samples = 100;
int max_iterations = 100;
double tol_rel_obj = 0.01;
double eta = 1.0;
bool adapt_engaged = true;
int adapt_iterations = 50;
int eval_elbo = 100;
int output_samples = 10;

stan::services::experimental::advi::meanfield(
model, context, seed, chain, init_radius, grad_samples, elbo_samples,
max_iterations, tol_rel_obj, eta, adapt_engaged, adapt_iterations,
eval_elbo, output_samples, interrupt, logger, init, parameter,
diagnostic);

EXPECT_TRUE(logger.find_info("Elapsed Time:") > 0)
<< "Should find 'Elapsed Time:' in logger info output";

bool found_in_parameter = false;
for (const auto& msg : parameter.string_values()) {
if (msg.find("Elapsed Time:") != std::string::npos) {
found_in_parameter = true;
break;
}
}
EXPECT_TRUE(found_in_parameter)
<< "Should find 'Elapsed Time:' in parameter_writer output";
}
21 changes: 21 additions & 0 deletions src/test/unit/services/optimize/bfgs_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,24 @@ TEST_F(ServicesOptimize, rosenbrock) {
EXPECT_FLOAT_EQ(return_code, 0);
EXPECT_EQ(19, interrupt.call_count());
}

TEST_F(ServicesOptimize, bfgs_timing_info) {
unsigned int seed = 0;
unsigned int chain = 1;
double init_radius = 0;

bool save_iterations = true;
int refresh = 0;
stan::test::unit::instrumented_interrupt interrupt;

int return_code = stan::services::optimize::bfgs(
model, context, seed, chain, init_radius, 0.001, 1e-12, 10000, 1e-8,
10000000, 1e-8, 2000, save_iterations, refresh, interrupt, logger, init,
parameter);

EXPECT_EQ(return_code, 0);
EXPECT_TRUE(parameter_ss.str().find("Elapsed Time:") != std::string::npos)
<< "Should find 'Elapsed Time:' in parameter_writer output";
EXPECT_TRUE(logger.find_info("Elapsed Time:") > 0)
<< "Should find 'Elapsed Time:' in logger info output";
}
Loading