/* * All or portions of this file Copyright (c) Amazon.com, Inc. or its affiliates or * its licensors. * * For complete copyright and license terms please see the LICENSE at the root of this * distribution (the "License"). All use of this software is governed by the License, * or, if provided, by the license below or the license accompanying this file. Do not * remove or modify any license notices. This file is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * */ #include #include #include #include #include #include #include #include #include #include #include #include namespace EMotionFX { class BlendTreeTestInputNode : public AnimGraphNode { public: AZ_RTTI(AnimGraphBindPoseNode, "{72595B5C-045C-4DB1-88A4-40BC4560D7AF}", AnimGraphNode) enum { OUTPUTPORT_RESULT = 0 }; BlendTreeTestInputNode(float value) : AnimGraphNode() , m_identificationValue(value) { InitOutputPorts(1); SetupOutputPortAsPose("Output Pose", OUTPUTPORT_RESULT, OUTPUTPORT_RESULT); } AZ::Color GetVisualColor() const override { return AZ::Color(1.0f, 1.0f, 0.0f, 1.0f); } bool GetHasOutputPose() const override { return true; } const char* GetPaletteName() const override { return "BlendTreeTestInputNode"; } AnimGraphObject::ECategory GetPaletteCategory() const override { return AnimGraphObject::CATEGORY_SOURCES; } AnimGraphPose* GetMainOutputPose(AnimGraphInstance* animGraphInstance) const override { return GetOutputPose(animGraphInstance, OUTPUTPORT_RESULT)->GetValue(); } bool InitAfterLoading(AnimGraph* animGraph) override { if (!AnimGraphNode::InitAfterLoading(animGraph)) { return false; } InitInternalAttributesForAllInstances(); Reinit(); return true; } void Output(AnimGraphInstance* animGraphInstance) override { RequestPoses(animGraphInstance); AnimGraphPose* outputAnimGraphPose = GetOutputPose(animGraphInstance, OUTPUTPORT_RESULT)->GetValue(); outputAnimGraphPose->InitFromBindPose(animGraphInstance->GetActorInstance()); Pose& outputPose = outputAnimGraphPose->GetPose(); // Output the assigned value of the node for each joint so that we can identify from which input each joint is coming from. const AZ::u32 numJoints = outputPose.GetNumTransforms(); for (AZ::u32 i = 0; i < numJoints; ++i) { Transform transform = outputPose.GetLocalSpaceTransform(i); transform.mPosition = AZ::Vector3(m_identificationValue, m_identificationValue, m_identificationValue); outputPose.SetLocalSpaceTransform(i, transform); } } private: float m_identificationValue; }; using MaskNodeTestParam = std::vector>; /* * The general idea is to identify the origin of the joints by embedding identification values into the joint transform * and inside the test extract that value and thus know from which mask input it belongs to. * We create a blend tree with a mask node having several input nodes. The first one representing the base pose and three * input mask nodes with a customizable mask which comes in by the test parameter. * We run several tests with different variations of masks and check if the output transforms for each joint corresponds with * the set masks and if the mask node picked and overwrote the correct transforms. */ class BlendTreeMaskNodeTestFixture : public AnimGraphFixture , public ::testing::WithParamInterface { public: void ConstructActor() override { m_actor = ActorFactory::CreateAndInit(5); } AZStd::vector ConstructMask(const std::vector& in) { AZStd::vector result; result.reserve(in.size()); for (const std::string& str : in) { result.emplace_back(AZStd::string(str.c_str(), str.size())); } return result; } AZ::Outcome FindMaskIndexForJoint(AZ::u32 jointIndex) const { const MaskNodeTestParam& param = GetParam(); Skeleton* skeleton = m_actor->GetSkeleton(); const size_t numMasks = param.size(); for (size_t maskIndex = 0; maskIndex < numMasks; ++maskIndex) { const std::vector& mask = param[maskIndex]; const Node* joint = skeleton->GetNode(jointIndex); const char* jointName = joint->GetName(); // Is joint in the current mask? Return the index in this case. if (std::find(mask.begin(), mask.end(), jointName) != mask.end()) { return AZ::Success(maskIndex); } } return AZ::Failure(); } void ConstructGraph() override { AnimGraphFixture::ConstructGraph(); const MaskNodeTestParam& param = GetParam(); m_blendTreeAnimGraph = AnimGraphFactory::Create(); m_rootStateMachine = m_blendTreeAnimGraph->GetRootStateMachine(); m_blendTree = m_blendTreeAnimGraph->GetBlendTreeNode(); /* +-----------+ | Base Pose +----------+ +-----------+ | | +----------+ >+-----------+ +-------+ | Mask 0 +----------->| Pose Mask +-------------->+ Final | +----------+ ------>| | +-------+ | >+-----------+ +----------+ | | | Mask 1 +-----+ | +----------+ | | +-------------+ | | Mask 3 +--------+ +-------------+ */ m_maskNode = aznew BlendTreeMaskNode(); m_blendTree->AddChildNode(m_maskNode); BlendTreeFinalNode* finalNode = aznew BlendTreeFinalNode(); m_blendTree->AddChildNode(finalNode); finalNode->AddConnection(m_maskNode, BlendTreeMaskNode::OUTPUTPORT_RESULT, BlendTreeFinalNode::PORTID_INPUT_POSE); m_basePoseNode = aznew BlendTreeTestInputNode(static_cast(m_basePosePosValue)); m_blendTree->AddChildNode(m_basePoseNode); m_maskNode->AddConnection(m_basePoseNode, BlendTreeTestInputNode::OUTPUTPORT_RESULT, BlendTreeMaskNode::INPUTPORT_BASEPOSE); for (AZ::u32 i = 0; i < m_numMaskInputNodes; ++i) { BlendTreeTestInputNode* inputNode = aznew BlendTreeTestInputNode(static_cast(i)); m_blendTree->AddChildNode(inputNode); m_maskNode->AddConnection(inputNode, BlendTreeTestInputNode::OUTPUTPORT_RESULT, BlendTreeMaskNode::INPUTPORT_START + i); m_maskInputNodes.push_back(inputNode); } const size_t numMasks = param.size(); ASSERT_EQ(numMasks, m_numMaskInputNodes) << "The number of provides masks in the parameter (" << numMasks << ") should match the number of created " << "input mask nodes (" << m_numMaskInputNodes << ")."; for (size_t i = 0; i < numMasks; ++i) { m_maskNode->SetMask(i, ConstructMask(param[i])); } m_blendTreeAnimGraph->InitAfterLoading(); } void SetUp() override { AnimGraphFixture::SetUp(); m_animGraphInstance->Destroy(); m_animGraphInstance = m_blendTreeAnimGraph->GetAnimGraphInstance(m_actorInstance, m_motionSet); } public: AZStd::unique_ptr m_blendTreeAnimGraph; BlendTreeMaskNode* m_maskNode = nullptr; BlendTreeTestInputNode* m_basePoseNode = nullptr; const size_t m_basePosePosValue = 100; // Special identification value for the base pose to easily distinguish it from the mask indices. std::vector m_maskInputNodes; size_t m_numMaskInputNodes = 3; BlendTree* m_blendTree = nullptr; }; TEST_P(BlendTreeMaskNodeTestFixture, MaskTests) { GetEMotionFX().Update(0.0f); Skeleton* skeleton = m_actor->GetSkeleton(); const AZ::u32 numJoints = skeleton->GetNumNodes(); TransformData* transformData = m_actorInstance->GetTransformData(); Pose* pose = transformData->GetCurrentPose(); // Iterate through the joints and make sure their transforms originate according to the mask setup. for (AZ::u32 jointIndex = 0; jointIndex < numJoints; jointIndex++) { const Node* joint = skeleton->GetNode(jointIndex); const char* jointName = joint->GetName(); const Transform& transform = pose->GetModelSpaceTransform(jointIndex); // The components of the position embed the origin. // If the compareValue equals m_basePosePosValue, it originates from the base pose input. // In case the joint is part of any of the masks and got overwriten by them, the compareValue represents the mask index. const size_t compareValue = static_cast(transform.mPosition.GetX()); AZ::Outcome maskIndex = FindMaskIndexForJoint(jointIndex); if (maskIndex.IsSuccess()) { EXPECT_EQ(compareValue, maskIndex.GetValue()) << "Joint '" << jointName << "' is part of mask " << maskIndex.GetValue() << " while the transform originated from input number " << compareValue << "."; } else { EXPECT_EQ(compareValue, m_basePosePosValue) << "Joint '" << jointName << "' is not part of any mask while the transform " << "originated from input number " << compareValue << ". It should originate " << "from the base pose input."; } } } std::vector maskNodeTestData { { {}, {}, {}, }, { { "rootJoint" }, {}, {}, }, { { "rootJoint", "joint2" }, {}, {}, }, { { "rootJoint", "joint1", "joint2" }, {}, {}, }, { { "rootJoint", "joint1", "joint2", "joint3", "joint4" }, {}, {}, }, { {}, { "joint1", "joint3" }, {}, }, { {}, {}, { "joint2", "joint4" }, }, { { "rootJoint", "joint1" }, { "joint3", "joint4" }, {}, }, { { "rootJoint", "joint1" }, {}, { "joint3", "joint4" }, }, { {}, { "rootJoint", "joint1" }, { "joint3", "joint4" }, }, { { "rootJoint" }, { "joint1", "joint2" }, { "joint3", "joint4" }, }, }; INSTANTIATE_TEST_CASE_P(BlendTreeMaskNode, BlendTreeMaskNodeTestFixture, ::testing::ValuesIn(maskNodeTestData)); } // namespace EMotionFX