diff --git a/src/stan/services/experimental/advi/fullrank.hpp b/src/stan/services/experimental/advi/fullrank.hpp index 5fba2e4e026..b56a72f3095 100644 --- a/src/stan/services/experimental/advi/fullrank.hpp +++ b/src/stan/services/experimental/advi/fullrank.hpp @@ -10,6 +10,9 @@ #include #include #include +#include +#include +#include #include #include @@ -57,6 +60,7 @@ int fullrank(Model& model, const stan::io::var_context& init, callbacks::writer& init_writer, callbacks::writer& parameter_writer, callbacks::writer& diagnostic_writer) { + auto start_time = std::chrono::steady_clock::now(); util::experimental_message(logger); stan::rng_t rng = util::create_rng(random_seed, chain); @@ -94,6 +98,10 @@ int fullrank(Model& model, const stan::io::var_context& init, return error_codes::SOFTWARE; } + auto end_time = std::chrono::steady_clock::now(); + util::write_timing(util::duration_diff(start_time, end_time), "ADVI", + parameter_writer, logger); + return stan::services::error_codes::OK; } } // namespace advi diff --git a/src/stan/services/experimental/advi/meanfield.hpp b/src/stan/services/experimental/advi/meanfield.hpp index 49bee285058..9c192f4003f 100644 --- a/src/stan/services/experimental/advi/meanfield.hpp +++ b/src/stan/services/experimental/advi/meanfield.hpp @@ -10,6 +10,9 @@ #include #include #include +#include +#include +#include #include #include @@ -57,6 +60,7 @@ int meanfield(Model& model, const stan::io::var_context& init, callbacks::writer& init_writer, callbacks::writer& parameter_writer, callbacks::writer& diagnostic_writer) { + auto start_time = std::chrono::steady_clock::now(); util::experimental_message(logger); stan::rng_t rng = util::create_rng(random_seed, chain); @@ -93,6 +97,10 @@ int meanfield(Model& model, const stan::io::var_context& init, return error_codes::SOFTWARE; } + auto end_time = std::chrono::steady_clock::now(); + util::write_timing(util::duration_diff(start_time, end_time), "ADVI", + parameter_writer, logger); + return stan::services::error_codes::OK; } } // namespace advi diff --git a/src/stan/services/optimize/bfgs.hpp b/src/stan/services/optimize/bfgs.hpp index 6f74388eb11..d6119af26e8 100644 --- a/src/stan/services/optimize/bfgs.hpp +++ b/src/stan/services/optimize/bfgs.hpp @@ -9,6 +9,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -57,6 +60,7 @@ int bfgs(Model& model, const stan::io::var_context& init, bool save_iterations, int refresh, callbacks::interrupt& interrupt, callbacks::logger& logger, callbacks::writer& init_writer, callbacks::writer& parameter_writer) { + auto start_time = std::chrono::steady_clock::now(); stan::rng_t rng = util::create_rng(random_seed, chain); std::vector disc_vector; @@ -211,6 +215,10 @@ int bfgs(Model& model, const stan::io::var_context& init, return_code = error_codes::SOFTWARE; } + auto end_time = std::chrono::steady_clock::now(); + util::write_timing(util::duration_diff(start_time, end_time), "Optimization", + parameter_writer, logger); + return return_code; } diff --git a/src/stan/services/optimize/laplace_sample.hpp b/src/stan/services/optimize/laplace_sample.hpp index 85d31274038..b7bc69a1d28 100644 --- a/src/stan/services/optimize/laplace_sample.hpp +++ b/src/stan/services/optimize/laplace_sample.hpp @@ -8,6 +8,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -176,6 +179,7 @@ int laplace_sample(const Model& model, const Eigen::VectorXd& theta_hat, int refresh, callbacks::interrupt& interrupt, callbacks::logger& logger, callbacks::writer& sample_writer, callbacks::structured_writer& hessian_writer) { + auto start_time = std::chrono::steady_clock::now(); try { internal::laplace_sample(model, theta_hat, draws, calculate_lp, random_seed, refresh, interrupt, logger, @@ -184,6 +188,11 @@ int laplace_sample(const Model& model, const Eigen::VectorXd& theta_hat, logger.error(e.what()); return error_codes::CONFIG; } + + auto end_time = std::chrono::steady_clock::now(); + util::write_timing(util::duration_diff(start_time, end_time), + "Laplace Approximation", sample_writer, logger); + return error_codes::OK; } diff --git a/src/stan/services/optimize/lbfgs.hpp b/src/stan/services/optimize/lbfgs.hpp index b38f79c1c23..f3f05ec74fe 100644 --- a/src/stan/services/optimize/lbfgs.hpp +++ b/src/stan/services/optimize/lbfgs.hpp @@ -9,6 +9,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -59,6 +62,7 @@ int lbfgs(Model& model, const stan::io::var_context& init, int refresh, callbacks::interrupt& interrupt, callbacks::logger& logger, callbacks::writer& init_writer, callbacks::writer& parameter_writer) { + auto start_time = std::chrono::steady_clock::now(); stan::rng_t rng = util::create_rng(random_seed, chain); std::vector disc_vector; @@ -206,6 +210,10 @@ int lbfgs(Model& model, const stan::io::var_context& init, return_code = error_codes::SOFTWARE; } + auto end_time = std::chrono::steady_clock::now(); + util::write_timing(util::duration_diff(start_time, end_time), "Optimization", + parameter_writer, logger); + return return_code; } diff --git a/src/stan/services/optimize/newton.hpp b/src/stan/services/optimize/newton.hpp index 0485281cd48..dd73000c6ba 100644 --- a/src/stan/services/optimize/newton.hpp +++ b/src/stan/services/optimize/newton.hpp @@ -10,6 +10,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -45,6 +48,7 @@ int newton(Model& model, const stan::io::var_context& init, callbacks::interrupt& interrupt, callbacks::logger& logger, callbacks::writer& init_writer, callbacks::writer& parameter_writer) { + auto start_time = std::chrono::steady_clock::now(); stan::rng_t rng = util::create_rng(random_seed, chain); std::vector disc_vector; @@ -135,6 +139,11 @@ int newton(Model& model, const stan::io::var_context& init, values.insert(values.begin(), {lp, static_cast(ret)}); parameter_writer(values); } + + auto end_time = std::chrono::steady_clock::now(); + util::write_timing(util::duration_diff(start_time, end_time), "Optimization", + parameter_writer, logger); + return error_codes::OK; } diff --git a/src/stan/services/util/write_timing.hpp b/src/stan/services/util/write_timing.hpp new file mode 100644 index 00000000000..9313406c0fd --- /dev/null +++ b/src/stan/services/util/write_timing.hpp @@ -0,0 +1,51 @@ +#ifndef STAN_SERVICES_UTIL_WRITE_TIMING_HPP +#define STAN_SERVICES_UTIL_WRITE_TIMING_HPP + +#include +#include +#include +#include + +namespace stan { +namespace services { +namespace util { + +/** + * Internal method to write timing information to a writer or logger. + * + * @param[in] delta_t time in seconds + * @param[in] label label for the timing info + * @param[in] writer output stream or logger + */ +template +void write_timing(double delta_t, const std::string& label, F& writer) { + std::string title(" Elapsed Time: "); + writer(""); + + std::stringstream ss; + ss << title << delta_t << " seconds (" << label << ")"; + writer(ss.str()); + + writer(""); +} + +/** + * Write timing information to both writer and logger. + * + * @param[in] delta_t time in seconds + * @param[in] label label for the timing info + * @param[in,out] writer output stream + * @param[in,out] logger messages are written through the logger + */ +inline void write_timing(double delta_t, const std::string& label, + callbacks::writer& writer, callbacks::logger& logger) { + write_timing(delta_t, label, writer); + auto logger_info = [&logger](const std::string& msg) { logger.info(msg); }; + write_timing(delta_t, label, logger_info); +} + +} // namespace util +} // namespace services +} // namespace stan + +#endif diff --git a/src/test/unit/services/experimental/advi/fullrank_test.cpp b/src/test/unit/services/experimental/advi/fullrank_test.cpp index bb122e6dcb3..dea1a6adf7c 100644 --- a/src/test/unit/services/experimental/advi/fullrank_test.cpp +++ b/src/test/unit/services/experimental/advi/fullrank_test.cpp @@ -95,3 +95,37 @@ TEST_F(ServicesExperimentalAdvi, fullrank) { EXPECT_EQ(0, interrupt.call_count()); } + +TEST_F(ServicesExperimentalAdvi, fullrank_timing_info) { + unsigned int seed = 0; + unsigned int chain = 1; + double init_radius = 0; + int grad_samples = 1; + int elbo_samples = 100; + int max_iterations = 100; + double tol_rel_obj = 0.01; + double eta = 1.0; + bool adapt_engaged = true; + int adapt_iterations = 50; + int eval_elbo = 100; + int output_samples = 10; + + stan::services::experimental::advi::fullrank( + model, context, seed, chain, init_radius, grad_samples, elbo_samples, + max_iterations, tol_rel_obj, eta, adapt_engaged, adapt_iterations, + eval_elbo, output_samples, interrupt, logger, init, parameter, + diagnostic); + + EXPECT_TRUE(logger.find_info("Elapsed Time:") > 0) + << "Should find 'Elapsed Time:' in logger info output"; + + bool found_in_parameter = false; + for (const auto& msg : parameter.string_values()) { + if (msg.find("Elapsed Time:") != std::string::npos) { + found_in_parameter = true; + break; + } + } + EXPECT_TRUE(found_in_parameter) + << "Should find 'Elapsed Time:' in parameter_writer output"; +} diff --git a/src/test/unit/services/experimental/advi/meanfield_test.cpp b/src/test/unit/services/experimental/advi/meanfield_test.cpp index 118b497922f..0c5f04daa53 100644 --- a/src/test/unit/services/experimental/advi/meanfield_test.cpp +++ b/src/test/unit/services/experimental/advi/meanfield_test.cpp @@ -95,3 +95,37 @@ TEST_F(ServicesExperimentalAdvi, meanfield) { EXPECT_EQ(0, interrupt.call_count()); } + +TEST_F(ServicesExperimentalAdvi, meanfield_timing_info) { + unsigned int seed = 0; + unsigned int chain = 1; + double init_radius = 0; + int grad_samples = 1; + int elbo_samples = 100; + int max_iterations = 100; + double tol_rel_obj = 0.01; + double eta = 1.0; + bool adapt_engaged = true; + int adapt_iterations = 50; + int eval_elbo = 100; + int output_samples = 10; + + stan::services::experimental::advi::meanfield( + model, context, seed, chain, init_radius, grad_samples, elbo_samples, + max_iterations, tol_rel_obj, eta, adapt_engaged, adapt_iterations, + eval_elbo, output_samples, interrupt, logger, init, parameter, + diagnostic); + + EXPECT_TRUE(logger.find_info("Elapsed Time:") > 0) + << "Should find 'Elapsed Time:' in logger info output"; + + bool found_in_parameter = false; + for (const auto& msg : parameter.string_values()) { + if (msg.find("Elapsed Time:") != std::string::npos) { + found_in_parameter = true; + break; + } + } + EXPECT_TRUE(found_in_parameter) + << "Should find 'Elapsed Time:' in parameter_writer output"; +} diff --git a/src/test/unit/services/optimize/bfgs_test.cpp b/src/test/unit/services/optimize/bfgs_test.cpp index ac376fe1910..3ba60bf1301 100644 --- a/src/test/unit/services/optimize/bfgs_test.cpp +++ b/src/test/unit/services/optimize/bfgs_test.cpp @@ -58,3 +58,24 @@ TEST_F(ServicesOptimize, rosenbrock) { EXPECT_FLOAT_EQ(return_code, 0); EXPECT_EQ(19, interrupt.call_count()); } + +TEST_F(ServicesOptimize, bfgs_timing_info) { + unsigned int seed = 0; + unsigned int chain = 1; + double init_radius = 0; + + bool save_iterations = true; + int refresh = 0; + stan::test::unit::instrumented_interrupt interrupt; + + int return_code = stan::services::optimize::bfgs( + model, context, seed, chain, init_radius, 0.001, 1e-12, 10000, 1e-8, + 10000000, 1e-8, 2000, save_iterations, refresh, interrupt, logger, init, + parameter); + + EXPECT_EQ(return_code, 0); + EXPECT_TRUE(parameter_ss.str().find("Elapsed Time:") != std::string::npos) + << "Should find 'Elapsed Time:' in parameter_writer output"; + EXPECT_TRUE(logger.find_info("Elapsed Time:") > 0) + << "Should find 'Elapsed Time:' in logger info output"; +} diff --git a/src/test/unit/services/optimize/laplace_sample_test.cpp b/src/test/unit/services/optimize/laplace_sample_test.cpp index e292305cb70..31c81bb51d4 100644 --- a/src/test/unit/services/optimize/laplace_sample_test.cpp +++ b/src/test/unit/services/optimize/laplace_sample_test.cpp @@ -38,7 +38,7 @@ TEST_F(ServicesLaplaceSample, values) { unsigned int seed = 1234; int refresh = 100; std::stringstream sample_ss; - stan::callbacks::stream_writer sample_writer(sample_ss, ""); + stan::callbacks::stream_writer sample_writer(sample_ss, "# "); int return_code = stan::services::laplace_sample( *model, theta_hat, draws, seed, refresh, interrupt, logger, sample_writer); @@ -113,7 +113,7 @@ TEST_F(ServicesLaplaceSample, hessianOutput) { unsigned int seed = 1234; int refresh = 100; std::stringstream sample_ss; - stan::callbacks::stream_writer sample_writer(sample_ss, ""); + stan::callbacks::stream_writer sample_writer(sample_ss, "# "); std::stringstream hessian_ss; stan::callbacks::json_writer hessian_writer{ @@ -138,7 +138,7 @@ TEST_F(ServicesLaplaceSample, noLP) { unsigned int seed = 1234; int refresh = 100; std::stringstream sample_ss; - stan::callbacks::stream_writer sample_writer(sample_ss, ""); + stan::callbacks::stream_writer sample_writer(sample_ss, "# "); stan::callbacks::structured_writer dummy_hessian_writer; int draws = 11; @@ -161,7 +161,7 @@ TEST_F(ServicesLaplaceSample, wrongSizeModeError) { unsigned int seed = 1234; int refresh = 1; std::stringstream sample_ss; - stan::callbacks::stream_writer sample_writer(sample_ss, ""); + stan::callbacks::stream_writer sample_writer(sample_ss, "# "); int RC = stan::services::laplace_sample(*model, theta_hat, draws, seed, refresh, interrupt, logger, sample_writer); @@ -175,7 +175,7 @@ TEST_F(ServicesLaplaceSample, nonPositiveDrawsError) { unsigned int seed = 1234; int refresh = 1; std::stringstream sample_ss; - stan::callbacks::stream_writer sample_writer(sample_ss, ""); + stan::callbacks::stream_writer sample_writer(sample_ss, "# "); int RC = stan::services::laplace_sample(*model, theta_hat, draws, seed, refresh, interrupt, logger, sample_writer); @@ -189,7 +189,7 @@ TEST_F(ServicesLaplaceSample, consoleOutput) { unsigned int seed = 1234; int refresh = 1; std::stringstream sample_ss; - stan::callbacks::stream_writer sample_writer(sample_ss, ""); + stan::callbacks::stream_writer sample_writer(sample_ss, "# "); std::stringstream logger_ss; stan::callbacks::stream_logger sample_logger(logger_ss, logger_ss, logger_ss, logger_ss, logger_ss); @@ -205,3 +205,24 @@ TEST_F(ServicesLaplaceSample, consoleOutput) { EXPECT_EQ(1, count_matches("Generating draws\niteration: 0\niteration: 1", console_str)); } + +TEST_F(ServicesLaplaceSample, laplace_timing_info) { + Eigen::VectorXd theta_hat(2); + theta_hat << 2, 3; + int draws = 10; + unsigned int seed = 1234; + int refresh = 1; + std::stringstream sample_ss; + stan::callbacks::stream_writer sample_writer(sample_ss, "# "); + std::stringstream logger_ss; + stan::callbacks::stream_logger sample_logger(logger_ss, logger_ss, logger_ss, + logger_ss, logger_ss); + int return_code = stan::services::laplace_sample( + *model, theta_hat, draws, seed, refresh, interrupt, sample_logger, + sample_writer); + EXPECT_EQ(stan::services::error_codes::OK, return_code); + EXPECT_TRUE(sample_ss.str().find("Elapsed Time:") != std::string::npos) + << "Should find 'Elapsed Time:' in sample_writer output"; + EXPECT_TRUE(logger_ss.str().find("Elapsed Time:") != std::string::npos) + << "Should find 'Elapsed Time:' in logger output"; +} diff --git a/src/test/unit/services/optimize/lbfgs_test.cpp b/src/test/unit/services/optimize/lbfgs_test.cpp index 243c9b86cc6..67b6b950ca4 100644 --- a/src/test/unit/services/optimize/lbfgs_test.cpp +++ b/src/test/unit/services/optimize/lbfgs_test.cpp @@ -58,3 +58,24 @@ TEST_F(ServicesOptimize, rosenbrock) { EXPECT_FLOAT_EQ(return_code, 0); EXPECT_EQ(22, interrupt.call_count()); } + +TEST_F(ServicesOptimize, lbfgs_timing_info) { + unsigned int seed = 0; + unsigned int chain = 1; + double init_radius = 0; + + bool save_iterations = true; + int refresh = 0; + stan::test::unit::instrumented_interrupt interrupt; + + int return_code = stan::services::optimize::lbfgs( + model, context, seed, chain, init_radius, 5, 0.001, 1e-12, 10000, 1e-8, + 10000000, 1e-8, 2000, save_iterations, refresh, interrupt, logger, init, + parameter); + + EXPECT_EQ(return_code, 0); + EXPECT_TRUE(parameter_ss.str().find("Elapsed Time:") != std::string::npos) + << "Should find 'Elapsed Time:' in parameter_writer output"; + EXPECT_TRUE(logger.find_info("Elapsed Time:") > 0) + << "Should find 'Elapsed Time:' in logger info output"; +} diff --git a/src/test/unit/services/optimize/newton_test.cpp b/src/test/unit/services/optimize/newton_test.cpp index 711756922be..7357ffd6222 100644 --- a/src/test/unit/services/optimize/newton_test.cpp +++ b/src/test/unit/services/optimize/newton_test.cpp @@ -90,3 +90,23 @@ TEST_F(ServicesOptimize, rosenbrock_no_save_iterations) { EXPECT_FLOAT_EQ(return_code, 0); EXPECT_LT(0, interrupt.call_count()); } + +TEST_F(ServicesOptimize, newton_timing_info) { + unsigned int seed = 0; + unsigned int chain = 1; + double init_radius = 0; + + int num_iterations = 1000; + bool save_iterations = true; + stan::test::unit::instrumented_interrupt interrupt; + + int return_code = stan::services::optimize::newton( + model, context, seed, chain, init_radius, num_iterations, save_iterations, + interrupt, logger, init, parameter); + + EXPECT_EQ(return_code, 0); + EXPECT_TRUE(parameter_ss.str().find("Elapsed Time:") != std::string::npos) + << "Should find 'Elapsed Time:' in parameter_writer output"; + EXPECT_TRUE(logger.find_info("Elapsed Time:") > 0) + << "Should find 'Elapsed Time:' in logger info output"; +}