From dd248e9ffe6c19c6470d740b16460ea2a28a1c64 Mon Sep 17 00:00:00 2001 From: Luca Beltrami Date: Fri, 18 Oct 2019 17:07:20 -0700 Subject: [PATCH 1/2] Example of a design for wrapping XAsyncBegin --- Source/HTTP/httpcall.cpp | 798 +++++++++++++++++++++++++++++---------- 1 file changed, 595 insertions(+), 203 deletions(-) diff --git a/Source/HTTP/httpcall.cpp b/Source/HTTP/httpcall.cpp index e92b188a..d6242ae0 100644 --- a/Source/HTTP/httpcall.cpp +++ b/Source/HTTP/httpcall.cpp @@ -5,6 +5,447 @@ #include "httpcall.h" #include "../Mock/lhc_mock.h" +namespace +{ + +class AsyncWrapper +{ +public: + enum class State + { + Created, + Polling, + PendingPoll, + PendingCallback, + Complete, + }; + + AsyncWrapper(_In_ XAsyncBlock* async, size_t rsize, State s): + m_async{ async }, m_resultSize{ rsize }, m_state{ s } + {} + + HRESULT MarkWaitingOnCallback() noexcept + { + assert(m_state == State::Polling); + + m_state = State::PendingCallback; + return S_OK; + } + + HRESULT ScheduleOnQueue() noexcept + { + return ScheduleOnQueueAfter({}); + } + + HRESULT ScheduleOnQueueAfter(std::chrono::milliseconds d) noexcept + { + assert(m_state == State::Polling); + HRESULT hr = XAsyncSchedule(m_async, static_cast(d.count())); + if (FAILED(hr)) + { + return hr; + } + + m_state = State::PendingPoll; + return S_OK; + } + + HRESULT Succeed() noexcept + { + return SucceedWithResultSize(m_resultSize); + } + + HRESULT SucceedWithResultSize(size_t size) noexcept + { + assert(m_state == State::Polling); + + m_state = State::Complete; + XAsyncComplete(m_async, S_OK, size); + + return S_OK; + } + + State GetState() const noexcept + { + return m_state; + } + + XAsyncBlock* GetRaw() const noexcept + { + return m_async; + } + +private: + XAsyncBlock* const m_async; // non owning + size_t const m_resultSize; + State m_state; +}; + +template +struct SizeOfHelper +{ + static constexpr size_t value = sizeof(T); +}; + +template<> +struct SizeOfHelper +{ + static constexpr size_t value = 0; +}; + +template +class AsyncRunnableStatelessBase +{ +public: + static HRESULT Run( + _In_opt_ void* identity, + _In_opt_z_ char const* identityName, + _In_ XAsyncBlock* async, + _In_ TContext* ctx + ) noexcept + { + return XAsyncBegin(async, ctx, identity, identityName, &AsyncRunnableStatelessBase::Provider); + } + +protected: + +protected: // default impls + static void GetResult(TContext*, size_t, void*) + { + assert(false); + } + static void Cancel(TContext*) {} + static void Cleanup(TContext*) {} + +private: + static HRESULT Provider(XAsyncOp op, _In_ XAsyncProviderData const* data) noexcept + { + static_assert(std::is_base_of, TDerived>::value, + "TDerived must be the class deriving from AsyncRunnableStatelessBase"); + + auto ctx = static_cast(data->context); + + switch (op) + { + case XAsyncOp::Begin: + try + { + AsyncWrapper aw{ data->async, SizeOfHelper::value, AsyncWrapper::State::Polling }; + + HRESULT hr = TDerived::Begin(ctx, aw); + if (FAILED(hr)) + { + return hr; + } + + if (aw.GetState() == AsyncWrapper::State::Polling) + { + assert(false); // forgot to complete or schedule again + return E_UNEXPECTED; + } + + return S_OK; + } + catch (...) // needs good catch return setup + { + return E_FAIL; + } + case XAsyncOp::DoWork: + try + { + AsyncWrapper aw{ data->async, SizeOfHelper::value, AsyncWrapper::State::Polling }; + + HRESULT hr = TDerived::DoWork(ctx, aw); + if (FAILED(hr)) + { + assert(hr != E_PENDING); // DoWork should never return E_PENDING + XAsyncComplete(data->async, hr, 0); + return S_OK; + } + + switch (aw.GetState()) + { + case AsyncWrapper::State::Created: + assert(false); // this cannot happen + return E_UNEXPECTED; + case AsyncWrapper::State::Polling: + assert(false); // forgot to complete or schedule again + XAsyncComplete(data->async, E_UNEXPECTED, 0); + return S_OK; + case AsyncWrapper::State::PendingPoll: + case AsyncWrapper::State::PendingCallback: + return E_PENDING; + case AsyncWrapper::State::Complete: + return S_OK; + } + } + catch (...) // needs good catch return setup + { + XAsyncComplete(data->async, E_FAIL, 0); + return S_OK; + } + case XAsyncOp::GetResult: + try + { + TDerived::GetResult(ctx, data->bufferSize, static_cast(data->buffer)); + return S_OK; + } + catch (...) // needs good catch return setup + { + // this seems really bad, is get result allowed to fail? + return E_FAIL; + } + case XAsyncOp::Cancel: + try + { + TDerived::Cancel(ctx); + return S_OK; + } + catch (...) // needs good catch return setup + { + // what to do here? let's assume the task will still complete normally + assert(false); + return S_OK; + } + case XAsyncOp::Cleanup: + // cleanup most definitely should be no fail, die hard on exceptions + { + TDerived::Cleanup(ctx); + return S_OK; + } + } + + // VS can't quite tell that we should always return early + assert(false); + return E_UNEXPECTED; + } +}; + +template +class AsyncRunnableBase +{ +public: + // split to allow other memory management strategies? + template + static HRESULT MakeAndRun( + _In_opt_ void* identity, + _In_opt_z_ char const* identityName, + _In_ XAsyncBlock* async, + TArgs&&... args + ) noexcept + { + try + { + auto self = std::make_unique(std::forward(args)...); + HRESULT hr = XAsyncBegin(async, self.get(), identity, identityName, &AsyncRunnableBase::Provider); + if (FAILED(hr)) + { + return hr; + } + + self.release(); + return S_OK; + } + catch (std::bad_alloc) + { + return E_OUTOFMEMORY; + } + catch (...) // needs good catch return setup + { + return E_FAIL; + } + } + +protected: + + HRESULT MarkWaitingOnCallback() noexcept + { + assert(m_state == State::Polling); + + m_state = State::PendingCallback; + return S_OK; + } + + HRESULT ScheduleOnQueue() noexcept + { + return ScheduleOnQueueAfter({}); + } + + HRESULT ScheduleOnQueueAfter(std::chrono::milliseconds d) noexcept + { + assert(m_state == State::Polling); + HRESULT hr = XAsyncSchedule(m_async, static_cast(d.count())); + if (FAILED(hr)) + { + return hr; + } + + m_state = State::PendingPoll; + return S_OK; + } + + HRESULT Succeed() noexcept + { + return SucceedWithResultSize(SizeOfHelper::value); + } + + HRESULT SucceedWithResultSize(size_t size) noexcept + { + assert(m_state == State::Polling || m_state == State::PendingCallback); + + m_state = State::Complete; + XAsyncComplete(m_async, S_OK, size); + + return S_OK; + } + + HRESULT Fail(HRESULT hr) noexcept + { + assert(FAILED(hr)); + assert(m_state == State::PendingCallback); // don't call Fail from DoWork, just return the failure instead + + m_state = State::Complete; + XAsyncComplete(m_async, hr, 0); + + return S_OK; + } + +protected: // Default impls + void GetResult(size_t, void*) + { + assert(false); + } + void Cancel() {} + +private: + using State = AsyncWrapper::State; + + static HRESULT Provider(XAsyncOp op, _In_ XAsyncProviderData const* data) noexcept + { + static_assert(std::is_base_of, TDerived>::value, + "TDerived must be the class deriving from AsyncRunnableBase"); + + auto self = static_cast(data->context); + assert(self); + self->m_async = data->async; + + switch (op) + { + case XAsyncOp::Begin: + try + { + assert(self->m_state == State::Created); + self->m_state = State::Polling; + + HRESULT hr = self->AsDerived()->Begin(); + if (FAILED(hr)) + { + self->m_state = State::Complete; + return hr; + } + + if (self->m_state == State::Polling) + { + assert(false); + self->m_state = State::Complete; + return E_FAIL; + } + + return S_OK; + } + catch (...) // needs good catch return setup + { + self->m_state = State::Complete; + return E_FAIL; + } + case XAsyncOp::DoWork: + try + { + assert(self->m_state == State::PendingPoll); + self->m_state = State::Polling; + + HRESULT hr = self->AsDerived()->DoWork(); + if (FAILED(hr)) + { + self->m_state = State::Complete; + XAsyncComplete(data->async, hr, 0); + return S_OK; + } + + switch (self->m_state) + { + case AsyncWrapper::State::Created: + assert(false); // this cannot happen + return E_UNEXPECTED; + case AsyncWrapper::State::Polling: + assert(false); // forgot to complete or schedule again + self->m_state = State::Complete; + XAsyncComplete(data->async, E_UNEXPECTED, 0); + return E_FAIL; + case AsyncWrapper::State::PendingPoll: + case AsyncWrapper::State::PendingCallback: + return E_PENDING; + case AsyncWrapper::State::Complete: + return S_OK; + } + } + catch (...) // needs good catch return setup + { + self->m_state = State::Complete; + XAsyncComplete(data->async, E_FAIL, 0); + return S_OK; + } + case XAsyncOp::GetResult: + try + { + assert(self->m_state == State::Complete); + self->AsDerived()->GetResult(data->bufferSize, static_cast(data->buffer)); + return S_OK; + } + catch (...) // needs good catch return setup + { + // this seems really bad, is get result allowed to fail? + return E_UNEXPECTED; + } + case XAsyncOp::Cancel: + try + { + self->AsDerived()->Cancel(); + return S_OK; + } + catch (...) // needs good catch return setup + { + // what to do here? let's assume the task will still complete normally + assert(false); + return S_OK; + } + case XAsyncOp::Cleanup: + // cleanup most definitely should be no fail, die hard on exceptions + { + assert(self->m_state == State::Complete); + + // take ownership of self + std::unique_ptr{ self->AsDerived() }; + return S_OK; + } + } + + // VS can't quite tell that we should always return early + assert(false); + return E_UNEXPECTED; + } + + TDerived* AsDerived() noexcept + { + return static_cast(this); + } + + State m_state = State::Created; + XAsyncBlock* m_async; +}; + +} + using namespace xbox::httpclient; const int MIN_DELAY_FOR_HTTP_INTERNAL_ERROR_IN_MS = 10000; @@ -99,57 +540,51 @@ HRESULT perform_http_call( _Inout_ XAsyncBlock* asyncBlock ) { - HRESULT hr = XAsyncBegin(asyncBlock, call, reinterpret_cast(perform_http_call), __FUNCTION__, - [](XAsyncOp opCode, const XAsyncProviderData* data) + class Runner : public AsyncRunnableStatelessBase { - auto httpSingleton = get_http_singleton(false); - if (nullptr == httpSingleton) + public: + static HRESULT Begin(HCCallHandle call, AsyncWrapper& aw) { - return E_HC_NOT_INITIALISED; + auto httpSingleton = get_http_singleton(false); + if (nullptr == httpSingleton) + { + return E_HC_NOT_INITIALISED; + } + + return aw.ScheduleOnQueueAfter(call->delayBeforeRetry); } - switch (opCode) + static HRESULT DoWork(HCCallHandle call, AsyncWrapper& aw) { - case XAsyncOp::DoWork: + auto httpSingleton = get_http_singleton(false); + if (nullptr == httpSingleton) { - HCCallHandle call = static_cast(data->context); - bool matchedMocks = false; + return E_HC_NOT_INITIALISED; + } - matchedMocks = Mock_Internal_HCHttpCallPerformAsync(call); - if (matchedMocks) - { - XAsyncComplete(data->async, S_OK, 0); - } - else // if there wasn't a matched mock, then real call + bool matchedMocks = false; + + matchedMocks = Mock_Internal_HCHttpCallPerformAsync(call); + if (matchedMocks) + { + return aw.Succeed(); + } + else // if there wasn't a matched mock, then real call + { + HttpPerformInfo const& info = httpSingleton->m_httpPerform; + if (info.handler == nullptr) { - HttpPerformInfo const& info = httpSingleton->m_httpPerform; - if (info.handler != nullptr) - { - try - { - info.handler(call, data->async, info.context, httpSingleton->m_performEnv.get()); - } - catch (...) - { - if (call->traceCall) { HC_TRACE_ERROR(HTTPCLIENT, "HCHttpCallPerform [ID %llu]: failed", static_cast(call)->id); } - } - } + assert(false); + return E_UNEXPECTED; } - return E_PENDING; + info.handler(call, aw.GetRaw(), info.context, httpSingleton->m_performEnv.get()); + return aw.MarkWaitingOnCallback(); } - - default: return S_OK; } - }); - - if (SUCCEEDED(hr)) - { - uint32_t delayInMilliseconds = static_cast(call->delayBeforeRetry.count()); - hr = XAsyncSchedule(asyncBlock, delayInMilliseconds); - } + }; - return hr; + return Runner::Run(reinterpret_cast(perform_http_call), __FUNCTION__, asyncBlock, call); } void clear_http_call_response(_In_ HCCallHandle call) @@ -327,207 +762,164 @@ bool should_fast_fail( } } - -class HcCallWrapper -{ -public: - HcCallWrapper(_In_ HC_CALL* call) - { - assert(call != nullptr); - if (call != nullptr) - { - m_call = HCHttpCallDuplicateHandle(call); - } - } - - ~HcCallWrapper() - { - if (m_call) - { - HCHttpCallCloseHandle(m_call); - } - } - - HC_CALL* get() - { - return m_call; - } - -private: - HC_CALL* m_call{ nullptr }; -}; - -typedef struct retry_context -{ - std::shared_ptr call; - XAsyncBlock* outerAsyncBlock; - XTaskQueueHandle outerQueue; -} retry_context; - -void retry_http_call_until_done( - _In_ retry_context* retryContext - ) +STDAPI +HCHttpCallPerformAsync( + _In_ HCCallHandle call, + _Inout_ XAsyncBlock* asyncBlock + ) noexcept +try { - auto httpSingleton = get_http_singleton(false); - if (nullptr == httpSingleton) + if (call == nullptr) { - XAsyncComplete(retryContext->outerAsyncBlock, S_OK, 0); + return E_INVALIDARG; } - auto requestStartTime = chrono_clock_t::now(); - HC_CALL* call = retryContext->call->get(); - if (call->retryIterationNumber == 0) - { - call->firstRequestStartTime = requestStartTime; - } - call->retryIterationNumber++; - if (call->traceCall) { HC_TRACE_INFORMATION(HTTPCLIENT, "HCHttpCallPerformExecute [ID %llu] Iteration %d", call->id, call->retryIterationNumber); } + if (call->traceCall) { HC_TRACE_INFORMATION(HTTPCLIENT, "HCHttpCallPerform [ID %llu]", call->id); } + call->performCalled = true; - http_retry_after_api_state apiState = httpSingleton->get_retry_state(call->retryAfterCacheId); - if (apiState.statusCode >= 400) + class Runner : public AsyncRunnableBase { - bool clearState = false; - if (should_fast_fail(apiState, call, requestStartTime, &clearState)) + public: + Runner(HCCallHandle call, XTaskQueueHandle queue) noexcept: + m_call{ call }, m_queue{ queue }, m_nestedQueue{ nullptr } { - HCHttpCallResponseSetStatusCode(call, apiState.statusCode); - if (call->traceCall) { HC_TRACE_INFORMATION(HTTPCLIENT, "HCHttpCallPerformExecute [ID %llu] Fast fail %d", call->id, apiState.statusCode); } - XAsyncComplete(retryContext->outerAsyncBlock, S_OK, 0); - return; + assert(m_call); + HCHttpCallDuplicateHandle(m_call); } - if( clearState ) + ~Runner() { - httpSingleton->clear_retry_state(call->retryAfterCacheId); + if (m_nestedQueue) + { + XTaskQueueCloseHandle(m_nestedQueue); + } + HCHttpCallCloseHandle(m_call); } - } - XTaskQueueHandle nestedQueue = nullptr; - if (retryContext->outerQueue != nullptr) - { - XTaskQueuePortHandle workPort; - XTaskQueueGetPort(retryContext->outerQueue, XTaskQueuePort::Work, &workPort); - XTaskQueueCreateComposite(workPort, workPort, &nestedQueue); - } - XAsyncBlock* nestedBlock = new XAsyncBlock{}; - nestedBlock->queue = nestedQueue; - nestedBlock->context = retryContext; - - nestedBlock->callback = [](XAsyncBlock* nestedAsyncBlock) - { - auto httpSingleton = get_http_singleton(false); - if (nullptr == httpSingleton) + HRESULT Begin() { - HC_TRACE_WARNING(HTTPCLIENT, "Http completed after HCCleanup was called. Aborting call."); - return; + return ScheduleOnQueue(); } - auto callStatus = XAsyncGetStatus(nestedAsyncBlock, false); - - retry_context* retryContext = static_cast(nestedAsyncBlock->context); - auto responseReceivedTime = chrono_clock_t::now(); - - uint32_t timeoutWindowInSeconds = 0; - HC_CALL* call = retryContext->call->get(); - HCHttpCallRequestGetTimeoutWindow(call, &timeoutWindowInSeconds); - - if (nestedAsyncBlock->queue != nullptr) + HRESULT DoWork() { - XTaskQueueCloseHandle(nestedAsyncBlock->queue); - } - delete nestedAsyncBlock; + auto httpSingleton = get_http_singleton(false); + if (nullptr == httpSingleton) + { + // todo only fail the first time + return E_HC_NOT_INITIALISED; + } - if (SUCCEEDED(callStatus) && http_call_should_retry(call, responseReceivedTime)) - { - if (call->traceCall) { HC_TRACE_INFORMATION(HTTPCLIENT, "HCHttpCallPerformExecute [ID %llu] Retry after %lld ms", call->id, call->delayBeforeRetry.count()); } - std::lock_guard lock(httpSingleton->m_callRoutedHandlersLock); - for (const auto& pair : httpSingleton->m_callRoutedHandlers) + auto requestStartTime = chrono_clock_t::now(); + if (m_call->retryIterationNumber == 0) { - pair.second.first(call, pair.second.second); + m_call->firstRequestStartTime = requestStartTime; + } + m_call->retryIterationNumber++; + if (m_call->traceCall) + { + HC_TRACE_INFORMATION(HTTPCLIENT, "HCHttpCallPerformExecute [ID %llu] Iteration %d", m_call->id, m_call->retryIterationNumber); } - clear_http_call_response(call); - retry_http_call_until_done(retryContext); - } - else - { - XAsyncComplete(retryContext->outerAsyncBlock, callStatus, 0); - } - }; + http_retry_after_api_state apiState = httpSingleton->get_retry_state(m_call->retryAfterCacheId); + if (apiState.statusCode >= 400) + { + bool clearState = false; + if (should_fast_fail(apiState, m_call, requestStartTime, &clearState)) + { + HCHttpCallResponseSetStatusCode(m_call, apiState.statusCode); + if (m_call->traceCall) + { + HC_TRACE_INFORMATION(HTTPCLIENT, "HCHttpCallPerformExecute [ID %llu] Fast fail %d", m_call->id, apiState.statusCode); + } + return Succeed(); + } - HRESULT hr = perform_http_call(httpSingleton, call, nestedBlock); - if (FAILED(hr)) - { - XAsyncComplete(retryContext->outerAsyncBlock, hr, 0); - return; - } -} + if (clearState) + { + httpSingleton->clear_retry_state(m_call->retryAfterCacheId); + } + } -STDAPI -HCHttpCallPerformAsync( - _In_ HCCallHandle call, - _Inout_ XAsyncBlock* asyncBlock - ) noexcept -try -{ - if (call == nullptr) - { - return E_INVALIDARG; - } + if (m_nestedQueue == nullptr && m_queue != nullptr) + { + XTaskQueuePortHandle workPort; + XTaskQueueGetPort(m_queue, XTaskQueuePort::Work, &workPort); + XTaskQueueCreateComposite(workPort, workPort, &m_nestedQueue); + } - if (call->traceCall) { HC_TRACE_INFORMATION(HTTPCLIENT, "HCHttpCallPerform [ID %llu]", call->id); } - call->performCalled = true; + auto nestedBlock = std::make_unique(); + nestedBlock->queue = m_nestedQueue; + nestedBlock->context = this; - std::shared_ptr retryContext = std::make_shared(); - retryContext->call = std::make_shared(static_cast(call)); // RAII will keep the HCCallHandle alive during HTTP call - retryContext->outerAsyncBlock = asyncBlock; - retryContext->outerQueue = asyncBlock->queue; - retry_context* rawRetryContext = static_cast(shared_ptr_cache::store(retryContext)); - if (rawRetryContext == nullptr) - { - HCHttpCallCloseHandle(call); - return E_HC_NOT_INITIALISED; - } + nestedBlock->callback = [](XAsyncBlock* nestedAsyncBlockPtr) + { + std::unique_ptr nestedAsyncBlock{ nestedAsyncBlockPtr }; + Runner* self = static_cast(nestedAsyncBlock->context); + self->InnerAsyncCallback(); + }; - HRESULT hr = XAsyncBegin(asyncBlock, rawRetryContext, reinterpret_cast(HCHttpCallPerformAsync), __FUNCTION__, - [](_In_ XAsyncOp op, _In_ const XAsyncProviderData* data) - { - auto httpSingleton = get_http_singleton(false); - if (nullptr == httpSingleton) - { - return E_HC_NOT_INITIALISED; + HRESULT hr = perform_http_call(httpSingleton, m_call, nestedBlock.get()); + if (FAILED(hr)) + { + return hr; + } + + nestedBlock.release(); + return MarkWaitingOnCallback(); } - switch (op) + void InnerAsyncCallback() noexcept + try { - case XAsyncOp::DoWork: - retry_http_call_until_done(static_cast(data->context)); - return E_PENDING; + auto httpSingleton = get_http_singleton(false); + if (nullptr == httpSingleton) + { + HC_TRACE_WARNING(HTTPCLIENT, "Http completed after HCCleanup was called. Aborting call."); + return; + } - case XAsyncOp::GetResult: - break; + auto responseReceivedTime = chrono_clock_t::now(); - case XAsyncOp::Cancel: - break; + uint32_t timeoutWindowInSeconds = 0; + HCHttpCallRequestGetTimeoutWindow(m_call, &timeoutWindowInSeconds); - case XAsyncOp::Cleanup: + if (http_call_should_retry(m_call, responseReceivedTime)) { - shared_ptr_cache::remove(data->context); - break; + if (m_call->traceCall) + { + HC_TRACE_INFORMATION(HTTPCLIENT, "HCHttpCallPerformExecute [ID %llu] Retry after %lld ms", m_call->id, m_call->delayBeforeRetry.count()); + } + std::lock_guard lock(httpSingleton->m_callRoutedHandlersLock); + for (const auto& pair : httpSingleton->m_callRoutedHandlers) + { + pair.second.first(m_call, pair.second.second); + } + + clear_http_call_response(m_call); + ScheduleOnQueue(); } - - default: - break; - } - return S_OK; - }); + Succeed(); + } + catch (...) // needs good catch into setup + { + Fail(E_FAIL); + } - if (hr == S_OK) - { - hr = XAsyncSchedule(asyncBlock, 0); - } + private: + HCCallHandle const m_call; + XTaskQueueHandle const m_queue; + XTaskQueueHandle m_nestedQueue; + }; - return hr; + return Runner::MakeAndRun( + reinterpret_cast(HCHttpCallPerformAsync), + __FUNCTION__, + asyncBlock, + call, + asyncBlock->queue + ); } CATCH_RETURN() From 18977c78e488f16a6fc82e994f592f893c9bf2a8 Mon Sep 17 00:00:00 2001 From: Luca Beltrami Date: Fri, 18 Oct 2019 17:38:02 -0700 Subject: [PATCH 2/2] Add some comments --- Source/HTTP/httpcall.cpp | 349 ++++++++++++++++++++++----------------- 1 file changed, 193 insertions(+), 156 deletions(-) diff --git a/Source/HTTP/httpcall.cpp b/Source/HTTP/httpcall.cpp index d6242ae0..65adeb9b 100644 --- a/Source/HTTP/httpcall.cpp +++ b/Source/HTTP/httpcall.cpp @@ -8,21 +8,64 @@ namespace { -class AsyncWrapper +// This code provides 2 class templates called AsyncRunnableBase and +// AsyncRunnableStatelessBase which help implment XAsync providers easily. +// The main goal of the design is to minimize the amount of boilerplate and edge +// cases each provider needs to implement. + +// In both cases the client is supposed to define a class that derives from one +// of the 2 bases and provide at least Begin and DoWork methods (which are +// are called when servicing that specific XAsyncOp). + +// XAsyncRunnableBase is the general purpose helper. It allocates an instance of +// the derived class and passes it down as the provider context, so the client +// can store arbitrary data (which can be passed to its constructor via +// MakeAndRun). +// The client can implement the following 4 provider operations as methods: +// - Begin +// - DoWork +// - GetResult (provided for clients returning void) +// - Cancel (optional) +// The destructor will be called in XAsyncOp::Cleanup to free resources. +// There are a number of protected methods that allow the client to schedule and +// complete the async operation (failure is always signalled by returning an +// error code from one of the provider methods) +template +class AsyncRunnableBase { public: - enum class State + // split to allow other memory management strategies? + template + static HRESULT MakeAndRun( + _In_opt_ void* identity, + _In_opt_z_ char const* identityName, + _In_ XAsyncBlock* async, + TArgs&&... args + ) noexcept { - Created, - Polling, - PendingPoll, - PendingCallback, - Complete, - }; + try + { + auto self = std::make_unique(std::forward(args)...); + HRESULT hr = XAsyncBegin(async, self.get(), identity, identityName, &AsyncRunnableBase::Provider); + if (FAILED(hr)) + { + return hr; + } - AsyncWrapper(_In_ XAsyncBlock* async, size_t rsize, State s): - m_async{ async }, m_resultSize{ rsize }, m_state{ s } - {} + self.release(); + return S_OK; + } + catch (std::bad_alloc) + { + return E_OUTOFMEMORY; + } + catch (...) // needs good catch return setup + { + return E_FAIL; + } + } + +protected: HRESULT MarkWaitingOnCallback() noexcept { @@ -52,12 +95,12 @@ class AsyncWrapper HRESULT Succeed() noexcept { - return SucceedWithResultSize(m_resultSize); + return SucceedWithResultSize(SizeOfHelper::value); } HRESULT SucceedWithResultSize(size_t size) noexcept { - assert(m_state == State::Polling); + assert(m_state == State::Polling || m_state == State::PendingCallback); m_state = State::Complete; XAsyncComplete(m_async, S_OK, size); @@ -65,113 +108,89 @@ class AsyncWrapper return S_OK; } - State GetState() const noexcept - { - return m_state; - } - - XAsyncBlock* GetRaw() const noexcept + HRESULT Fail(HRESULT hr) noexcept { - return m_async; - } - -private: - XAsyncBlock* const m_async; // non owning - size_t const m_resultSize; - State m_state; -}; - -template -struct SizeOfHelper -{ - static constexpr size_t value = sizeof(T); -}; + assert(FAILED(hr)); + assert(m_state == State::PendingCallback); // don't call Fail from DoWork, just return the failure instead -template<> -struct SizeOfHelper -{ - static constexpr size_t value = 0; -}; + m_state = State::Complete; + XAsyncComplete(m_async, hr, 0); -template -class AsyncRunnableStatelessBase -{ -public: - static HRESULT Run( - _In_opt_ void* identity, - _In_opt_z_ char const* identityName, - _In_ XAsyncBlock* async, - _In_ TContext* ctx - ) noexcept - { - return XAsyncBegin(async, ctx, identity, identityName, &AsyncRunnableStatelessBase::Provider); + return S_OK; } -protected: - -protected: // default impls - static void GetResult(TContext*, size_t, void*) +protected: // Default impls + void GetResult(size_t, void*) { assert(false); } - static void Cancel(TContext*) {} - static void Cleanup(TContext*) {} + void Cancel() {} private: + using State = AsyncWrapper::State; + static HRESULT Provider(XAsyncOp op, _In_ XAsyncProviderData const* data) noexcept { - static_assert(std::is_base_of, TDerived>::value, - "TDerived must be the class deriving from AsyncRunnableStatelessBase"); + static_assert(std::is_base_of, TDerived>::value, + "TDerived must be the class deriving from AsyncRunnableBase"); - auto ctx = static_cast(data->context); + auto self = static_cast(data->context); + assert(self); + self->m_async = data->async; switch (op) { case XAsyncOp::Begin: try { - AsyncWrapper aw{ data->async, SizeOfHelper::value, AsyncWrapper::State::Polling }; + assert(self->m_state == State::Created); + self->m_state = State::Polling; - HRESULT hr = TDerived::Begin(ctx, aw); + HRESULT hr = self->AsDerived()->Begin(); if (FAILED(hr)) { + self->m_state = State::Complete; return hr; } - if (aw.GetState() == AsyncWrapper::State::Polling) + if (self->m_state == State::Polling) { - assert(false); // forgot to complete or schedule again - return E_UNEXPECTED; + assert(false); + self->m_state = State::Complete; + return E_FAIL; } return S_OK; } catch (...) // needs good catch return setup { + self->m_state = State::Complete; return E_FAIL; } case XAsyncOp::DoWork: try { - AsyncWrapper aw{ data->async, SizeOfHelper::value, AsyncWrapper::State::Polling }; + assert(self->m_state == State::PendingPoll); + self->m_state = State::Polling; - HRESULT hr = TDerived::DoWork(ctx, aw); + HRESULT hr = self->AsDerived()->DoWork(); if (FAILED(hr)) { - assert(hr != E_PENDING); // DoWork should never return E_PENDING + self->m_state = State::Complete; XAsyncComplete(data->async, hr, 0); return S_OK; } - switch (aw.GetState()) + switch (self->m_state) { case AsyncWrapper::State::Created: assert(false); // this cannot happen return E_UNEXPECTED; case AsyncWrapper::State::Polling: assert(false); // forgot to complete or schedule again + self->m_state = State::Complete; XAsyncComplete(data->async, E_UNEXPECTED, 0); - return S_OK; + return E_FAIL; case AsyncWrapper::State::PendingPoll: case AsyncWrapper::State::PendingCallback: return E_PENDING; @@ -181,24 +200,26 @@ class AsyncRunnableStatelessBase } catch (...) // needs good catch return setup { + self->m_state = State::Complete; XAsyncComplete(data->async, E_FAIL, 0); return S_OK; } case XAsyncOp::GetResult: try { - TDerived::GetResult(ctx, data->bufferSize, static_cast(data->buffer)); + assert(self->m_state == State::Complete); + self->AsDerived()->GetResult(data->bufferSize, static_cast(data->buffer)); return S_OK; } catch (...) // needs good catch return setup { // this seems really bad, is get result allowed to fail? - return E_FAIL; + return E_UNEXPECTED; } case XAsyncOp::Cancel: try { - TDerived::Cancel(ctx); + self->AsDerived()->Cancel(); return S_OK; } catch (...) // needs good catch return setup @@ -210,7 +231,10 @@ class AsyncRunnableStatelessBase case XAsyncOp::Cleanup: // cleanup most definitely should be no fail, die hard on exceptions { - TDerived::Cleanup(ctx); + assert(self->m_state == State::Complete); + + // take ownership of self + std::unique_ptr{ self->AsDerived() }; return S_OK; } } @@ -219,44 +243,32 @@ class AsyncRunnableStatelessBase assert(false); return E_UNEXPECTED; } + + TDerived* AsDerived() noexcept + { + return static_cast(this); + } + + State m_state = State::Created; + XAsyncBlock* m_async; }; -template -class AsyncRunnableBase +// Helper for AsyncRunnableStatelessBase +class AsyncWrapper { public: - // split to allow other memory management strategies? - template - static HRESULT MakeAndRun( - _In_opt_ void* identity, - _In_opt_z_ char const* identityName, - _In_ XAsyncBlock* async, - TArgs&&... args - ) noexcept + enum class State { - try - { - auto self = std::make_unique(std::forward(args)...); - HRESULT hr = XAsyncBegin(async, self.get(), identity, identityName, &AsyncRunnableBase::Provider); - if (FAILED(hr)) - { - return hr; - } - - self.release(); - return S_OK; - } - catch (std::bad_alloc) - { - return E_OUTOFMEMORY; - } - catch (...) // needs good catch return setup - { - return E_FAIL; - } - } + Created, + Polling, + PendingPoll, + PendingCallback, + Complete, + }; -protected: + AsyncWrapper(_In_ XAsyncBlock* async, size_t rsize, State s): + m_async{ async }, m_resultSize{ rsize }, m_state{ s } + {} HRESULT MarkWaitingOnCallback() noexcept { @@ -286,12 +298,12 @@ class AsyncRunnableBase HRESULT Succeed() noexcept { - return SucceedWithResultSize(SizeOfHelper::value); + return SucceedWithResultSize(m_resultSize); } HRESULT SucceedWithResultSize(size_t size) noexcept { - assert(m_state == State::Polling || m_state == State::PendingCallback); + assert(m_state == State::Polling); m_state = State::Complete; XAsyncComplete(m_async, S_OK, size); @@ -299,89 +311,127 @@ class AsyncRunnableBase return S_OK; } - HRESULT Fail(HRESULT hr) noexcept + State GetState() const noexcept { - assert(FAILED(hr)); - assert(m_state == State::PendingCallback); // don't call Fail from DoWork, just return the failure instead + return m_state; + } - m_state = State::Complete; - XAsyncComplete(m_async, hr, 0); + XAsyncBlock* GetRaw() const noexcept + { + return m_async; + } - return S_OK; +private: + XAsyncBlock* const m_async; // non owning + size_t const m_resultSize; + State m_state; +}; + +template +struct SizeOfHelper +{ + static constexpr size_t value = sizeof(T); +}; + +template<> +struct SizeOfHelper +{ + static constexpr size_t value = 0; +}; + +// AsyncRunnableStatelessBase is a specialized helper for building providers +// that do not own their context. Unlike AsyncRunnableBase it does not allocate +// at all, relying on the TContext* object to carry any information it needs. +// The client can implement the following 5 provider operations as static +// methods: +// - Begin +// - DoWork +// - GetResult (provided for clients returning void) +// - Cancel (optional) +// - Cleanup (optional) +// Each of these methods is passed a pointer to the context. Begin and DoWork +// are also given an AsyncWrapper object (by reference) which can be used to +// schedule or complete the operation (like AsyncRunnableBase, returning an +// error code will fail the operation). +template +class AsyncRunnableStatelessBase +{ +public: + static HRESULT Run( + _In_opt_ void* identity, + _In_opt_z_ char const* identityName, + _In_ XAsyncBlock* async, + _In_ TContext* ctx + ) noexcept + { + return XAsyncBegin(async, ctx, identity, identityName, &AsyncRunnableStatelessBase::Provider); } -protected: // Default impls - void GetResult(size_t, void*) +protected: + +protected: // default impls + static void GetResult(TContext*, size_t, void*) { assert(false); } - void Cancel() {} + static void Cancel(TContext*) {} + static void Cleanup(TContext*) {} private: - using State = AsyncWrapper::State; - static HRESULT Provider(XAsyncOp op, _In_ XAsyncProviderData const* data) noexcept { - static_assert(std::is_base_of, TDerived>::value, - "TDerived must be the class deriving from AsyncRunnableBase"); + static_assert(std::is_base_of, TDerived>::value, + "TDerived must be the class deriving from AsyncRunnableStatelessBase"); - auto self = static_cast(data->context); - assert(self); - self->m_async = data->async; + auto ctx = static_cast(data->context); switch (op) { case XAsyncOp::Begin: try { - assert(self->m_state == State::Created); - self->m_state = State::Polling; + AsyncWrapper aw{ data->async, SizeOfHelper::value, AsyncWrapper::State::Polling }; - HRESULT hr = self->AsDerived()->Begin(); + HRESULT hr = TDerived::Begin(ctx, aw); if (FAILED(hr)) { - self->m_state = State::Complete; return hr; } - if (self->m_state == State::Polling) + if (aw.GetState() == AsyncWrapper::State::Polling) { - assert(false); - self->m_state = State::Complete; - return E_FAIL; + assert(false); // forgot to complete or schedule again + return E_UNEXPECTED; } return S_OK; } catch (...) // needs good catch return setup { - self->m_state = State::Complete; return E_FAIL; } case XAsyncOp::DoWork: try { - assert(self->m_state == State::PendingPoll); - self->m_state = State::Polling; + AsyncWrapper aw{ data->async, SizeOfHelper::value, AsyncWrapper::State::Polling }; - HRESULT hr = self->AsDerived()->DoWork(); + HRESULT hr = TDerived::DoWork(ctx, aw); if (FAILED(hr)) { - self->m_state = State::Complete; + assert(hr != E_PENDING); // DoWork should never return E_PENDING XAsyncComplete(data->async, hr, 0); return S_OK; } - switch (self->m_state) + switch (aw.GetState()) { case AsyncWrapper::State::Created: assert(false); // this cannot happen return E_UNEXPECTED; case AsyncWrapper::State::Polling: assert(false); // forgot to complete or schedule again - self->m_state = State::Complete; XAsyncComplete(data->async, E_UNEXPECTED, 0); - return E_FAIL; + return S_OK; case AsyncWrapper::State::PendingPoll: case AsyncWrapper::State::PendingCallback: return E_PENDING; @@ -391,26 +441,24 @@ class AsyncRunnableBase } catch (...) // needs good catch return setup { - self->m_state = State::Complete; XAsyncComplete(data->async, E_FAIL, 0); return S_OK; } case XAsyncOp::GetResult: try { - assert(self->m_state == State::Complete); - self->AsDerived()->GetResult(data->bufferSize, static_cast(data->buffer)); + TDerived::GetResult(ctx, data->bufferSize, static_cast(data->buffer)); return S_OK; } catch (...) // needs good catch return setup { // this seems really bad, is get result allowed to fail? - return E_UNEXPECTED; + return E_FAIL; } case XAsyncOp::Cancel: try { - self->AsDerived()->Cancel(); + TDerived::Cancel(ctx); return S_OK; } catch (...) // needs good catch return setup @@ -422,10 +470,7 @@ class AsyncRunnableBase case XAsyncOp::Cleanup: // cleanup most definitely should be no fail, die hard on exceptions { - assert(self->m_state == State::Complete); - - // take ownership of self - std::unique_ptr{ self->AsDerived() }; + TDerived::Cleanup(ctx); return S_OK; } } @@ -434,14 +479,6 @@ class AsyncRunnableBase assert(false); return E_UNEXPECTED; } - - TDerived* AsDerived() noexcept - { - return static_cast(this); - } - - State m_state = State::Created; - XAsyncBlock* m_async; }; }