// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
#pragma once

#pragma warning( push, 3 )
#pragma warning( disable : 4706 4002 )

#include <boost/exception/all.hpp>
#include <boost/exception_ptr.hpp>
#pragma warning( push, 3 )
#include <boost/thread.hpp>
#pragma warning( pop )

#include <tbb/atomic.h>
#include <tbb/concurrent_queue.h>
#include <tbb/task_scheduler_init.h>

#pragma warning( pop )

namespace frantic {
namespace threads {

#pragma warning( push )
#pragma warning( disable : 4127 )

namespace {

// Stand-in exception used to carry a worker thread's failure over to another thread. Our current compiler does not
// provide the support needed to transport the original exception object across threads, so only its message text is
// preserved (attached as a boost::error_info payload); the dynamic type of the original exception is lost.
// NOTE: If preserving the exact type of particular exceptions matters, add a more specific catch handler ahead of the
// generic std::exception handler that knows how to copy those exceptions.
struct thread_exception : virtual boost::exception, virtual std::exception {
    // Tag/payload pair under which the captured message string is attached to this exception.
    typedef boost::error_info<thread_exception, std::string> info_type;

    // Returns the attached message when one is present, otherwise a fixed placeholder string.
    virtual const char* what() const throw() {
        const std::string* const attachedMsg = boost::get_error_info<info_type>( *this );
        return attachedMsg ? attachedMsg->c_str() : "<unknown thread exception>";
    }
};

} // namespace

template <class ProducerConsumerModel>
class buffered_producer_consumer {
    typedef typename ProducerConsumerModel::buffer_type buffer_type;

    /**
     * This class is designed to hold a temporary ptr to an object, deleting it if destroyed while holding an item.
     * I made this instead of using std::unique_ptr because I wanted to be able to set the stored ptr by assigning to
     * a reference (tbb::concurrent_queue::push for example).
     */
    class ptr_guard {
        buffer_type* item;

      public:
        ptr_guard() { item = NULL; }
        ~ptr_guard() {
            if( item )
                delete item;
        }

        inline buffer_type* get() { return item; } // Get a ptr to the currenly owned item
        inline buffer_type*& get_ref() {
            assert( !item );
            return item;
        } // Returns a ref so this pointer can be set. Must not be called with item is non-NULL.
        inline buffer_type* release() {
            buffer_type* result = item;
            item = NULL;
            return result;
        } // Releases ownership of the current item.
    };

    // A counter of the number of active threads. When it becomes zero, all producer threads have finished.
    tbb::atomic<int> numThreadsRemaining;

    // Will be atomically set to true when a thread has registered an exception.
    tbb::atomic<bool> errorOccurred;

    // An exception ptr for transferring the exception (or more likely an approximation) to the main thread.
    boost::exception_ptr error;

    // A pair of queues for passing empty and full buffers between producers and the consumer. The producer threads will
    // be responsible for creating the initial buffers.
    tbb::concurrent_queue<buffer_type*> emptyItems, fullItems;

    template <class T>
    inline static bool concurrent_queue_try_pop( tbb::concurrent_queue<T>& queue, T& outValue ) {
#if TBB_VERSION_MAJOR < 3
        return queue.pop_if_present( outValue );
#else
        return queue.try_pop( outValue );
#endif
    }

    /**
     * This is the code executed by worker threads.
     * @param prod The producer meta-object used to create buffers and Producer::thread_instance objects that do the
     * actual production.
     */
    void producer_fn() {
        try {
            typedef typename ProducerConsumerModel::producer_instance producer_instance;
            producer_instance threadProd( *m_pcModel );

            tbb::task_scheduler_init tsched;

            ptr_guard theItem;

            // Create an extra buffer for later use by this thread.
            theItem.get_ref() = m_pcModel->create_buffer();
            emptyItems.push( theItem.release() );

            // Create the buffer we will be filling.
            theItem.get_ref() = m_pcModel->create_buffer();
            threadProd.init_buffer( theItem.get() );

            while( 1 ) {
                boost::this_thread::interruption_point();

                if( !threadProd.can_produce_more() ) {
                    // This thread should exit since there is no more data to produce. We may need to flush if our
                    // current buffer is non-empty.
                    if( !threadProd.is_buffer_empty( theItem.get() ) ) {
                        threadProd.finish_buffer( theItem.get() );
                        fullItems.push( theItem.release() );
                    } else {
                        emptyItems.push( theItem.release() );
                    }

                    break;
                }

                if( threadProd.is_buffer_full( theItem.get() ) ) {
                    threadProd.finish_buffer( theItem.get() );
                    fullItems.push( theItem.release() );

                    while( !concurrent_queue_try_pop( emptyItems, theItem.get_ref() ) ) {
                        boost::this_thread::interruption_point();
                        boost::this_thread::yield();
                    }

                    threadProd.init_buffer( theItem.get() );
                }

                // We have a non-full buffer so produce data to go into it.
                threadProd.fill_buffer( theItem.get() );
            }
        } catch( const boost::thread_interrupted& ) {
            ; // Do nothing
        } catch( const std::exception& e ) {

            // Atomically determine if another thread has already thrown an exception, and if not store our exception.
            // Otherwise ignore it in favor of the already stored exception.
            if( errorOccurred.compare_and_swap( true, false ) == false )
                error = boost::copy_exception( thread_exception() << thread_exception::info_type( e.what() ) );
        }

        --numThreadsRemaining;
        return;
    }

  private:
    ProducerConsumerModel* m_pcModel;

    boost::thread_group m_threads;

  public:
    buffered_producer_consumer() {
        numThreadsRemaining = 0;
        errorOccurred = false;
    }

    ~buffered_producer_consumer() {
        m_threads.interrupt_all();
        m_threads.join_all();

        buffer_type* item;
        while( concurrent_queue_try_pop( emptyItems, item ) )
            delete item;
        while( concurrent_queue_try_pop( fullItems, item ) )
            delete item;
    }

    void reset( ProducerConsumerModel& pcModel, unsigned int numThreads = 0 ) {
        m_pcModel = &pcModel;

        numThreads = std::max( 1u, numThreads == 0 ? boost::thread::hardware_concurrency() - 1u : numThreads );

        // Set the atomic counter to track the number of outstanding worker threads.
        numThreadsRemaining = numThreads;
        errorOccurred = false;

        for( unsigned int i = 0; i < numThreads; ++i )
            m_threads.create_thread( boost::bind( &buffered_producer_consumer::producer_fn, this ) );
    }

    /**
     * This function will consume data produced by the worker threads until they have all exited.
     * @param cons The consumer implementation object that will process filled buffers.
     */
    void run( bool untilDone = true ) {
        try {
            typedef typename ProducerConsumerModel::consumer_instance consumer_instance;
            consumer_instance threadCons( *m_pcModel );

            ptr_guard theItem;

            do {
                if( !concurrent_queue_try_pop( fullItems, theItem.get_ref() ) ) {
                    threadCons.do_idle_process();

                    while( !concurrent_queue_try_pop( fullItems, theItem.get_ref() ) ) {
                        if( errorOccurred )
                            throw boost::thread_interrupted(); // Can't throw this->error yet, since it may not be
                                                               // finished being created. Throw this instead and sync
                                                               // with all threads.
                        if( numThreadsRemaining == 0 ) {
                            if( concurrent_queue_try_pop(
                                    fullItems,
                                    theItem.get_ref() ) ) // Check again to make sure we didn't get an item before
                                                          // #threads went to 0.
                                break;
                            return;
                        }
                        boost::this_thread::yield();
                    }
                }

                threadCons.consume_buffer( theItem.get() );

                emptyItems.push( theItem.release() );
            } while( untilDone );
        } catch( const boost::thread_interrupted& ) {
            m_threads.interrupt_all();
            m_threads.join_all();

            boost::rethrow_exception( error );
        } catch( ... ) {
            m_threads.interrupt_all();
            m_threads.join_all();

            throw;
        }
    }

    void return_finished_item( std::unique_ptr<buffer_type> item ) { emptyItems.push( item.release() ); }

    std::unique_ptr<buffer_type> steal_finished_item( bool waitForItem = true ) {
        std::unique_ptr<buffer_type> result;

        try {
            buffer_type* theItem = NULL;

            if( waitForItem ) {
                while( !concurrent_queue_try_pop( fullItems, theItem ) ) {
                    if( errorOccurred )
                        throw boost::thread_interrupted(); // Can't throw this->error yet, since it may not be finished
                                                           // being created. Throw this instead and sync with all
                                                           // threads.
                    if( numThreadsRemaining == 0 ) {
                        concurrent_queue_try_pop(
                            fullItems,
                            theItem ); // Check again to make sure we didn't get an item before #threads went to 0.
                        break;
                    }
                    boost::this_thread::yield();
                }
            } else if( !concurrent_queue_try_pop( fullItems, theItem ) && errorOccurred ) {
                throw boost::thread_interrupted(); // Can't throw this->error yet, since it may not be finished being
                                                   // created. Throw this instead and sync with all threads.
            }

            result.reset( theItem );
        } catch( const boost::thread_interrupted& ) {
            m_threads.interrupt_all();
            m_threads.join_all();

            boost::rethrow_exception( error );
        } catch( ... ) {
            m_threads.interrupt_all();
            m_threads.join_all();

            throw;
        }

        return result;
    }

    bool is_done() { return ( numThreadsRemaining == 0 ); }
};

/**
 * This algorithm will operate the Poducer/Consumer pattern for the given template arguments. The construction of the
 * types is defined above.
 *
 * This algorithm will (in parallel) fill buffer objects using ProducerConsumerModel::producer_instance::fill_buffer()
 * and pass them serially to a single Consumer object via. ProducerConsumerModel::consumer_instance::consume_buffer.
 * Once the buffer is consumed it will be made available to another ProducerConsumerModel::producer_instance.
 *
 * @note ProducerConsumerModel::consumer_instance::consume_buffer() will always be called in the context of the thread
 * calling buffered_producer_consumer().
 *
 * @param pcModel An instance of ProducerConsumerModel that supports the interface described above.
 * @param numWorkers The number of worker threads to use, or 0 if the system should decide.
 */
template <class ProducerConsumerModel>
void do_buffered_producer_consumer( ProducerConsumerModel& pcModel, unsigned int numWorkers = 0 ) {
    buffered_producer_consumer<ProducerConsumerModel> theImpl;
    theImpl.reset( pcModel, numWorkers );
    theImpl.run();
}

#pragma warning( pop )

} // namespace threads
} // namespace frantic