// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
#pragma once

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace frantic {
namespace magma {
namespace simple_compiler {

/**
 * Traits classes for converting back and forth between magma_data_type and C++ types.
 *
 * The primary template is only declared, never defined; only the explicit specializations below are usable.
 * Each specialization provides:
 *   - is_compatible( type ): true when a runtime magma_data_type can be bound to the C++ type.
 *   - get_type(): a copy of the magma_data_type registered with magma_singleton under the matching name.
 *   - get_static_map(): a lazily built channel_map holding a single channel named "Value".
 */
template struct traits;

/**
 * @overload For float type.
 */
template <> struct traits {
    // Accepts exactly one float32 element.
    inline static bool is_compatible( const frantic::magma::magma_data_type& type ) {
        return type.m_elementType == frantic::channels::data_type_float32 && type.m_elementCount == 1;
    }

    // Looks up the type registered as "Float". The returned pointer is dereferenced without a null check.
    inline static frantic::magma::magma_data_type get_type() {
        return *frantic::magma::magma_singleton::get_named_data_type( _T("Float") );
    }

    // Shared output map. It is only declared here; the out-of-line definition is presumably in a .cpp — confirm.
    static frantic::channels::channel_map s_map;

    // Defines the single "Value" channel the first time it is called, then returns the cached map.
    // NOTE(review): the lazy initialization is unsynchronized; confirm first use cannot race between threads.
    inline static const frantic::channels::channel_map& get_static_map() {
        if( !s_map.channel_definition_complete() ) {
            s_map.define_channel( _T("Value") );
            s_map.end_channel_definition();
        }
        return s_map;
    }
};

/**
 * @overload For int type.
 */
template <> struct traits {
    // Accepts exactly one int32 element.
    inline static bool is_compatible( const frantic::magma::magma_data_type& type ) {
        return type.m_elementType == frantic::channels::data_type_int32 && type.m_elementCount == 1;
    }

    // Looks up the type registered as "Int". The returned pointer is dereferenced without a null check.
    inline static frantic::magma::magma_data_type get_type() {
        return *frantic::magma::magma_singleton::get_named_data_type( _T("Int") );
    }

    // Shared output map (declared only; the definition lives outside this header — confirm where).
    static frantic::channels::channel_map s_map;

    // Same lazy single-"Value"-channel initialization as the float specialization (also unsynchronized).
    inline static const frantic::channels::channel_map& get_static_map() {
        if( !s_map.channel_definition_complete() ) {
            s_map.define_channel( _T("Value") );
            s_map.end_channel_definition();
        }
        return s_map;
    }
};

/**
 * @overload For bool type.
 */
template <> struct traits {
    // Integer type used to represent a bool value (chosen through boost::int_t).
    typedef boost::int_t::fast bool_int_type;

    // Accepts exactly one element whose data type matches channel_data_type_traits
    // (presumably instantiated for bool_int_type — confirm).
    inline static bool is_compatible( const frantic::magma::magma_data_type& type ) {
        return type.m_elementType == frantic::channels::channel_data_type_traits::data_type() &&
               type.m_elementCount == 1;
    }

    // Looks up the type registered as "Bool". The returned pointer is dereferenced without a null check.
    inline static frantic::magma::magma_data_type get_type() {
        return *frantic::magma::magma_singleton::get_named_data_type( _T("Bool") );
    }

    // Shared output map (declared only; the definition lives outside this header — confirm where).
    static frantic::channels::channel_map s_map;

    // Same lazy single-"Value"-channel initialization as the float specialization (also unsynchronized).
    inline static const frantic::channels::channel_map& get_static_map() {
        if( !s_map.channel_definition_complete() ) {
            s_map.define_channel( _T("Value") );
            s_map.end_channel_definition();
        }
        return s_map;
    }
};

/**
 * @overload For vec3 type.
 */
template <> struct traits {
    // Accepts exactly three float32 elements.
    inline static bool is_compatible( const frantic::magma::magma_data_type& type ) {
        return type.m_elementType == frantic::channels::data_type_float32 && type.m_elementCount == 3;
    }

    // Looks up the type registered as "Vec3". The returned pointer is dereferenced without a null check.
    inline static frantic::magma::magma_data_type get_type() {
        return *frantic::magma::magma_singleton::get_named_data_type( _T("Vec3") );
    }

    // Shared output map (declared only; the definition lives outside this header — confirm where).
    static frantic::channels::channel_map s_map;

    // Same lazy single-"Value"-channel initialization as the float specialization (also unsynchronized).
    inline static const frantic::channels::channel_map& get_static_map() {
        if( !s_map.channel_definition_complete() ) {
            s_map.define_channel( _T("Value") );
            s_map.end_channel_definition();
        }
        return s_map;
    }
};

/**
 * @overload For quat type.
*/
template <> struct traits {
    // Accepts exactly four float32 elements.
    inline static bool is_compatible( const frantic::magma::magma_data_type& type ) {
        return type.m_elementType == frantic::channels::data_type_float32 && type.m_elementCount == 4;
    }

    // Looks up the type registered as "Quat". The returned pointer is dereferenced without a null check.
    inline static frantic::magma::magma_data_type get_type() {
        return *frantic::magma::magma_singleton::get_named_data_type( _T("Quat") );
    }

    // Shared output map (declared only; the definition lives outside this header — confirm where).
    static frantic::channels::channel_map s_map;

    // Defines the single "Value" channel the first time it is called, then returns the cached map.
    // NOTE(review): unsynchronized lazy initialization — confirm first use cannot race between threads.
    inline static const frantic::channels::channel_map& get_static_map() {
        if( !s_map.channel_definition_complete() ) {
            s_map.define_channel( _T("Value") );
            s_map.end_channel_definition();
        }
        return s_map;
    }
};

// This namespace stores types needed to implement base_compiler, but that shouldn't be used outside this
// implementation. (ie. Don't touch!)
namespace detail {

/**
 * Template class for invoking the correct operator() member function of a functor, based on the function signature
 * provided.
 *
 * Every specialization's apply() treats 'stack' as a raw byte buffer:
 *   - argument i is read by reinterpreting the bytes at 'stack + inputs[i]';
 *   - the result location is 'stack + output'.
 * No bounds or alignment checks are made here; the offsets are used exactly as given.
 * @tparam The signature of the operator() member function to invoke.
 */
template struct function_invoker;

/**
 * @overload To handle R(P1) type functions.
 * The functor's return value is assigned to the storage at 'stack + output'.
 */
template struct function_invoker {
    template inline static void apply( char* stack, const Functor& fn, std::ptrdiff_t output,
                                       const std::ptrdiff_t inputs[] ) {
        *reinterpret_cast( stack + output ) = fn( *reinterpret_cast( stack + inputs[0] ) );
    }
};

/**
 * @overload To handle R(P1,P2) type functions.
 * The functor's return value is assigned to the storage at 'stack + output'.
 */
template struct function_invoker {
    template inline static void apply( char* stack, const Functor& fn, std::ptrdiff_t output,
                                       const std::ptrdiff_t inputs[] ) {
        *reinterpret_cast( stack + output ) =
            fn( *reinterpret_cast( stack + inputs[0] ), *reinterpret_cast( stack + inputs[1] ) );
    }
};

/**
 * @overload To handle R(P1,P2,P3) type functions.
 * The functor's return value is assigned to the storage at 'stack + output'.
 */
template struct function_invoker {
    template inline static void apply( char* stack, const Functor& fn, std::ptrdiff_t output,
                                       const std::ptrdiff_t inputs[] ) {
        *reinterpret_cast( stack + output ) =
            fn( *reinterpret_cast( stack + inputs[0] ), *reinterpret_cast( stack + inputs[1] ),
                *reinterpret_cast( stack + inputs[2] ) );
    }
};

/**
 * @overload To handle R(P1,P2,P3,P4) type functions.
 * The functor's return value is assigned to the storage at 'stack + output'.
 */
template struct function_invoker {
    template inline static void apply( char* stack, const Functor& fn, std::ptrdiff_t output,
                                       const std::ptrdiff_t inputs[] ) {
        *reinterpret_cast( stack + output ) =
            fn( *reinterpret_cast( stack + inputs[0] ), *reinterpret_cast( stack + inputs[1] ),
                *reinterpret_cast( stack + inputs[2] ), *reinterpret_cast( stack + inputs[3] ) );
    }
};

/**
 * @overload To handle void(void*,P1) type functions.
 * The functor receives 'stack + output' as its first argument and writes its own results there; nothing is
 * assigned by the invoker.
 */
template struct function_invoker {
    template inline static void apply( char* stack, const Functor& fn, std::ptrdiff_t output,
                                       const std::ptrdiff_t inputs[] ) {
        fn( stack + output, *reinterpret_cast( stack + inputs[0] ) );
    }
};

/**
 * @overload To handle void(void*,P1,P2) type functions.
 * The functor receives 'stack + output' as its first argument and writes its own results there; nothing is
 * assigned by the invoker.
 */
template struct function_invoker {
    template inline static void apply( char* stack, const Functor& fn, std::ptrdiff_t output,
                                       const std::ptrdiff_t inputs[] ) {
        fn( stack + output, *reinterpret_cast( stack + inputs[0] ), *reinterpret_cast( stack + inputs[1] ) );
    }
};

/**
 * @overload To handle void(void*,P1,P2,P3) type functions.
*/
// The functor receives 'stack + output' as its first argument and writes its own results there; nothing is
// assigned by the invoker.
template struct function_invoker {
    template inline static void apply( char* stack, const Functor& fn, std::ptrdiff_t output,
                                       const std::ptrdiff_t inputs[] ) {
        fn( stack + output, *reinterpret_cast( stack + inputs[0] ), *reinterpret_cast( stack + inputs[1] ),
            *reinterpret_cast( stack + inputs[2] ) );
    }
};

/**
 * Adapts a functor to the base_compiler::expression interface.
 *
 * The bound signature's parameter types determine NUM_INPUTS and choose the function_invoker specialization.
 * Presumably this is the signature selected by select_binding — confirm.
 * Byte offsets into the state's buffer are stored as follows:
 *   - one offset per input, set by set_input();
 *   - one offset for the output, set by set_output().
 * Both apply() and the static internal_apply() forward the stored functor and offsets to the invoker.
 * NOTE(review): m_inputRelPtrs and m_outputRelPtr are not initialized by the constructors; they hold indeterminate
 * values until set_input()/set_output() are called. set_input() does not range-check inputIndex against NUM_INPUTS.
 */
template class expression_impl : public base_compiler::expression {
    typedef typename boost::function_types::result_type::type return_type;
    typedef typename boost::function_types::parameter_types::type param_types;
    typedef function_invoker invoker;

    // Number of parameters in the bound signature; sizes the input offset array.
    enum { NUM_INPUTS = boost::mpl::size::value };

    std::ptrdiff_t m_inputRelPtrs[NUM_INPUTS];
    std::ptrdiff_t m_outputRelPtr;
    Functor m_fn;

    // The DUMMY template parameter is to avoid complete template
    // specialization, which is not permitted here in GCC.
    // Primary case: the output layout comes from the cached traits map of the return type.
    template struct get_output_map_impl {
        inline static const frantic::channels::channel_map& apply( const Functor& /*fn*/ ) {
            return traits::get_static_map();
        }
    };

    // Specialized case: the functor supplies its own output layout.
    template struct get_output_map_impl {
        inline static const frantic::channels::channel_map& apply( const Functor& fn ) { return fn.get_output_map(); }
    };

    // This is non-virtual on purpose!
    // It is static so get_runtime_ptr() can hand out its address; '_this' is cast back to this class.
    static void internal_apply( const base_compiler::expression* _this, base_compiler::state& data ) {
        invoker::apply( (char*)data.get_output_pointer( 0 ), // HACK Should be passing state directly and using its get/set member functions.
                        static_cast( _this )->m_fn, static_cast( _this )->m_outputRelPtr,
                        static_cast( _this )->m_inputRelPtrs );
    }

  public:
    expression_impl( const Functor& fn )
        : m_fn( fn ) {}

    // Accepts a functor wrapped in 'movable'.
    expression_impl( const movable& fn )
        : m_fn( fn ) {}

    // expression_impl( BOOST_RV_REF(Functor) fn ) : m_fn( fn )
    //{}

    virtual ~expression_impl() {}

    // Records the byte offset of input 'inputIndex' (no range check).
    virtual void set_input( std::size_t inputIndex, std::ptrdiff_t relPtr ) { m_inputRelPtrs[inputIndex] = relPtr; }

    // Records the byte offset where the result is written.
    virtual void set_output( std::ptrdiff_t relPtr ) { m_outputRelPtr = relPtr; }

    virtual const frantic::channels::channel_map& get_output_map() const {
        // For regular function signatures, we use a static channel_map instance cached in the traits structure.
        // For void(void*,...) style functions (which indicate multiple and/or complex returned types) we will ask
        // the implementation functor itself to provide a channel_map.
        return get_output_map_impl::apply( m_fn );
    }

    // Same work as internal_apply, reached through the virtual interface.
    virtual void apply( base_compiler::state& data ) const {
        invoker::apply( (char*)data.get_output_pointer( 0 ) /*HACK for now*/, m_fn, m_outputRelPtr, m_inputRelPtrs );
    }

    // We can avoid virtual function calls by extracting a direct function ptr and storing that alongside the ptr to the
    // subexpression. This might be quicker since it eliminates an indirection, but it also might not matter much
    // relative to the actual overhead of calling the function. Gotta test, test, test
    //
    // Another way this might get used is if we switch to a LLVM compiled approach. Since some nodes will still be
    // implemented in MSVC++ we need to use C ptrs in order to cross between LLVM & MSVC since their C++ ABI will not
    // match (ie. Cannot directly call virtual member functions from LLVM that were compiled in MSVC++).
    virtual runtime_ptr get_runtime_ptr() const { return &internal_apply; }
};

/**
 * Expression that writes a fixed value into its output temporary each time it is applied.
 * It has no inputs: set_input() raises an internal error. The output layout is the cached traits map for T.
 */
template class constant_expression : public base_compiler::expression {
    std::ptrdiff_t m_outPtr;
    T m_value;

    // Static counterpart of apply(), exposed through get_runtime_ptr().
    static void internal_apply( const expression* _this, base_compiler::state& data ) {
        data.set_temporary( static_cast( _this )->m_outPtr, static_cast( _this )->m_value );
    }

  public:
    constant_expression( typename boost::call_traits::param_type val )
        : m_value( val ) {}

    virtual ~constant_expression() {}

    // A constant has no inputs, so any attempt to connect one is an internal error.
    virtual void set_input( std::size_t /*inputIndex*/, std::ptrdiff_t /*relPtr*/ ) { THROW_MAGMA_INTERNAL_ERROR(); }

    virtual void set_output( std::ptrdiff_t relPtr ) { m_outPtr = relPtr; }

    virtual const frantic::channels::channel_map& get_output_map() const { return traits::get_static_map(); }

    virtual void apply( base_compiler::state& data ) const { data.set_temporary( m_outPtr, m_value ); }

    virtual runtime_ptr get_runtime_ptr() const { return &internal_apply; }
};

/**
 * We need to determine a modified function signature when working with expressions of the form void(void*,...) since
 * the void* is not a real input, but is instead an output parameter. That needs to be stripped out.
 *
 * The baseline case is no modification to the signature.
 * @tparam The function signature of an expression.
 */
template struct modified_signature {
    typedef FnSig type;
};

// Handles stripping out the void* output parameter for signatures of this type
template struct modified_signature {
    typedef void( type )( P1 );
};

// Handles stripping out the void* output parameter for signatures of this type
template struct modified_signature {
    typedef void( type )( P1, P2 );
};

// Handles stripping out the void* output parameter for signatures of this type
template struct modified_signature {
    typedef void( type )( P1, P2, P3 );
};

/**
 * We support tagging a class as movable so it isn't copied on assignment. We need to get at the underlying type though
 * so this template metafunction strips off the movable tag.
The baseline version does nothing. */ template struct remove_movable { typedef T type; }; // Stips the movable wrapper type of the specified type. template struct remove_movable> { typedef T type; }; /** * Checks that the runtime 'inputs' sequence corresponds to the types stored in the metasequence defined by * 'ParamIterator' and 'ParamIteratorEnd'. Returns true if the inputs match the sequence. * @tparam ParamIterator An iterator over an MPL type sequence, pointing to the "current" type we are iterating over. * @tparam ParamIteratorEnd An iterator over an MPL type sequence, pointing to the end of the metasequence we are * iterating over. * @param inputs A pointer to the "current" element in the list of inputs we are matching against. */ template struct check_params { inline static bool apply( const base_compiler::temporary_meta* inputs ) { typedef typename boost::mpl::deref::type CurParamType; if( !traits::is_compatible( inputs->first ) ) return false; return check_params::type, ParamIteratorEnd>::apply( inputs + 1 ); } }; /** * Recursive termination case when the end iterator is reached. */ template struct check_params { inline static bool apply( const base_compiler::temporary_meta* ) { return true; } }; /** * Given a sequence of function signatures (denoted by iterators SigIterator and SigIteratorEnd), finds one that matches * the types of the 'inputs' sequence and allocates a subclass of expression that delegates to the provided functor's * operator() function that matches the selected signature. */ template struct select_binding { inline static std::unique_ptr apply( const Functor& fn, const base_compiler::temporary_meta inputs[], std::size_t numInputs ) { typedef typename boost::mpl::deref::type RealFnSig; // Get the signature from the iterator typedef typename modified_signature::type CurFnSig; // Modify the signature to strip out the void* parameter for the void(void*,...) style signatures. 
typedef typename boost::function_types::parameter_types::type CurParamTypes; // Convert the signature to a boost::mpl::vector<...> typedef typename boost::mpl::begin::type CurParamsIt; // Get iterator to first element in parameter list typedef typename boost::mpl::end::type CurParamsItEnd; // Get iterator to end of parameter list std::size_t ParamCount = boost::mpl::size::value; // Extract the number of parameters so we can // make sure 'inputs' matches. if( numInputs == ParamCount && check_params::apply( inputs ) ) { typedef typename remove_movable::type RealFunctor; // Strip off a 'movable' wrapper type if // present. std::unique_ptr result( new expression_impl( fn ) ); for( std::size_t i = 0; i < numInputs; ++i ) result->set_input( i, inputs[i].second ); return result; } else { // The inputs didn't match the current signature so advance the signature iterator and check the next one. typedef typename boost::mpl::next::type NextFnSig; return select_binding::apply( fn, inputs, numInputs ); } } }; /** * Recursive termination case when the end iterator is reached. */ template struct select_binding { inline static std::unique_ptr apply( const Functor&, const base_compiler::temporary_meta[], std::size_t ) { return std::unique_ptr(); } }; struct type_collector { std::vector* m_pCollection; type_collector( std::vector& binding ) : m_pCollection( &binding ) {} template void operator()( const T& ) const { m_pCollection->push_back( traits::get_type() ); } }; template struct collect_bindings { static void apply( std::vector>::iterator outBindingsIt ) { typedef typename boost::mpl::deref::type RealFnSig; // Get the signature from the iterator typedef typename modified_signature::type CurFnSig; // Modify the signature to strip out the void* parameter for the void(void*,...) style signatures. 
typedef typename boost::function_types::parameter_types::type CurParamTypes; // Convert the signature to a boost::mpl::vector<...> boost::mpl::for_each( type_collector( *outBindingsIt ) ); typedef typename boost::mpl::next::type NextFnSig; collect_bindings::apply( ++outBindingsIt ); } }; template struct collect_bindings { static void apply( std::vector>::iterator& ) {} }; } // namespace detail template void base_compiler::compile_impl( expression_id exprID, const typename ExpressionMetaData::type& fn, const temporary_meta inputs[], std::size_t numInputs, std::size_t expectedNumOutputs ) { typedef typename boost::mpl::begin::type BindingsIt; typedef typename boost::mpl::end::type BindingsItEnd; // Create an instance of a subexpression subclass that binds to the input types. Assigns the inputs and returns the // new instance. std::unique_ptr result = detail::select_binding::apply( fn, inputs, numInputs ); if( !result.get() ) { std::vector foundInputs; std::vector> expectedInputs; foundInputs.reserve( numInputs ); expectedInputs.resize( boost::mpl::size::value ); for( std::size_t i = 0; i < numInputs; ++i ) { foundInputs.push_back( inputs[i].first ); if( foundInputs.back().m_typeName == NULL ) { const magma_data_type* pBetterType = magma_singleton::get_matching_data_type( foundInputs.back().m_elementType, foundInputs.back().m_elementCount ); if( pBetterType != NULL ) foundInputs.back() = *pBetterType; } } detail::collect_bindings::apply( expectedInputs.begin() ); BOOST_THROW_EXCEPTION( magma_exception() << magma_exception::node_id( exprID ) << magma_exception::error_name( _T("Invalid input combination") ) << magma_exception::found_inputs( boost::move( foundInputs ) ) << magma_exception::expected_inputs( boost::move( expectedInputs ) ) ); } const frantic::channels::channel_map& resultsMap = result->get_output_map(); if( resultsMap.channel_count() != expectedNumOutputs ) THROW_MAGMA_INTERNAL_ERROR( exprID, expectedNumOutputs, result->get_output_map().channel_count() ); 
result->set_output( this->allocate_temporaries( exprID, resultsMap ) ); this->register_expression( std::move( result ) ); } template void base_compiler::compile_impl( const frantic::magma::magma_node_base& node, const typename ExpressionMetaData::type& fn ) { if( node.get_num_inputs() != ExpressionMetaData::ARITY ) THROW_MAGMA_INTERNAL_ERROR( node.get_id(), node.get_num_inputs(), (int)ExpressionMetaData::ARITY ); temporary_meta inputs[ExpressionMetaData::ARITY]; for( std::size_t i = 0; i < ExpressionMetaData::ARITY; ++i ) inputs[i] = *this->get_input_value( node, (int)i, false ); return this->compile_impl( node.get_id(), fn, inputs, ExpressionMetaData::ARITY, node.get_num_outputs() ); } namespace detail { template struct type_to_name {}; template <> struct type_to_name { static const frantic::tchar* get_name() { return _T("Geometry"); } }; template <> struct type_to_name { static const frantic::tchar* get_name() { return _T("Particles"); } }; template <> struct type_to_name { static const frantic::tchar* get_name() { return _T("Objects"); } }; }; // namespace detail template Interface* base_compiler::get_interface( expression_id exprID, int index, bool allowNull ) { if( exprID == magma_interface::INVALID_ID ) { if( !allowNull ) throw magma_exception() << magma_exception::error_name( _T("Unconnected input socket") ); return NULL; } std::pair input( exprID, index ); // Interface* result = NULL; // do{ const std::pair& value = this->get_value( input.first, input.second ); if( value.first.m_elementType != frantic::channels::data_type_invalid ) // HACK This should be replaced later with a generalized type that // encapsulates the actual ptr value. 
throw magma_exception() << magma_exception::connected_id( input.first ) << magma_exception::connected_output_index( input.second ) << magma_exception::found_type( value.first ) << magma_exception::error_name( frantic::tstring() + _T("Expected a ") + detail::type_to_name::get_name() + _T(" node") ); // HACK: This is temporary, but is the general idea of what I want. if( value.first.m_typeName != frantic::tstring( detail::type_to_name::get_name() ) ) throw magma_exception() << magma_exception::connected_id( input.first ) << magma_exception::connected_output_index( input.second ) << magma_exception::found_type( value.first ) << magma_exception::error_name( frantic::tstring() + _T("Expected a ") + detail::type_to_name::get_name() + _T(" node") ); return reinterpret_cast( value.second ); // frantic::magma::magma_node_base* node = m_magma->get_node( input.first ); // if( !node ) // THROW_MAGMA_INTERNAL_ERROR(); // result = dynamic_cast( node ); // if( !result ){ // std::pair prev = input; // if( frantic::magma::nodes::magma_elbow_node* elbow = dynamic_cast( //node ) ) input = elbow->get_input( input.second ); else if( frantic::magma::nodes::magma_blop_node* blop = //dynamic_cast( node ) ) input.first = blop->get__internal_output_id(); // else if( frantic::magma::nodes::magma_blop_input_node* blopSocket = //dynamic_cast( node ) ) input = blopSocket->get_output( input.second //); else if( frantic::magma::nodes::magma_loop_inputs_node* loopInputs = //dynamic_cast( node ) ) input = //loopInputs->get_output_socket_passthrough( input.second ); else throw magma_exception() << //magma_exception::error_name( _T("Incorrect input type") ); // //Make sure the next item is not disconnected. This should have been caught by visit() unless we have a //bug. 
if( input.first == magma_interface::INVALID_ID ) throw magma_exception() << magma_exception::node_id( //prev.first ) << magma_exception::input_index( prev.second ) << magma_exception::error_name( _T("Unconnected input //socket") ); // } //}while( !result ); // return result; } template Interface* base_compiler::get_input_interface( frantic::magma::magma_node_base& node, int inputIndex, bool allowUnconnected ) { Interface* result = NULL; std::pair input = node.get_input( inputIndex ); if( input.first == magma_interface::INVALID_ID ) { if( !allowUnconnected ) throw magma_exception() << magma_exception::node_id( node.get_id() ) << magma_exception::input_index( inputIndex ) << magma_exception::error_name( _T("Unconnected input socket") ); } else { try { result = this->get_interface( input.first, input.second ); } catch( magma_exception& e ) { if( e.get_node_id() == magma_interface::INVALID_ID ) e << magma_exception::node_id( node.get_id() ) << magma_exception::input_index( inputIndex ); throw; } } return result; } } // namespace simple_compiler } // namespace magma } // namespace frantic