/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "./utils.h"

namespace tvm {
namespace tir {

/**************** Constructor ****************/

/*! \brief Create a new block random variable backed by a freshly allocated BlockRVNode. */
BlockRV::BlockRV() { this->data_ = make_object<BlockRVNode>(); }

/*! \brief Create a new loop random variable backed by a freshly allocated LoopRVNode. */
LoopRV::LoopRV() { this->data_ = make_object<LoopRVNode>(); }

/**************** GetSRef ****************/

/*!
 * \brief Look up the sref recorded for a statement in the schedule state.
 * \param stmt The statement to look up; used as the key into `state->stmt2ref`.
 * \return The sref mapped to `stmt` in `stmt2ref`.
 * \note Reports an "IndexError" through LOG(FATAL) when `stmt` has no entry in `stmt2ref`.
 */
StmtSRef ScheduleNode::GetSRef(const StmtNode* stmt) const {
  ScheduleState state = this->state();
  auto it = state->stmt2ref.find(stmt);
  if (it == state->stmt2ref.end()) {
    LOG(FATAL) << "IndexError: The stmt doesn't exist in the IR";
  }
  return it->second;
}

/**************** FFI ****************/

TVM_REGISTER_NODE_TYPE(BlockRVNode);
TVM_REGISTER_NODE_TYPE(LoopRVNode);
TVM_REGISTER_OBJECT_TYPE(ScheduleNode);

// Each of these globals forwards directly to the ScheduleNode member of the same name.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetMod")  //
    .set_body_method<Schedule>(&ScheduleNode::mod);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState")  //
    .set_body_method<Schedule>(&ScheduleNode::state);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace")  //
    .set_body_method<Schedule>(&ScheduleNode::trace);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy")  //
    .set_body_method<Schedule>(&ScheduleNode::Copy);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed")  //
    .set_body_method<Schedule>(&ScheduleNode::Seed);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleForkSeed")  //
    .set_body_method<Schedule>(&ScheduleNode::ForkSeed);

/**************** (FFI) Constructor ****************/

TVM_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV(); });
TVM_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return LoopRV(); });

// Both factory globals below receive (mod, seed, debug_mask, error_render_level) from the FFI and
// forward them in that same order. `seed` and `debug_mask` are (presumably) both integral types, so
// passing them swapped would still compile and silently mix up RNG seeding and debug checks.
// NOTE(review): order assumed to match the declaration of Schedule::Concrete / Schedule::Traced
// (mod, seed, debug_mask, level) -- confirm against schedule.h.
TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule")
    .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed,
                       int debug_mask, int error_render_level) -> Schedule {
      return Schedule::Concrete(mod, seed, debug_mask,
                                static_cast<ScheduleErrorRenderLevel>(error_render_level));
    });
TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule")
    .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed,
                       int debug_mask, int error_render_level) -> Schedule {
      return Schedule::Traced(mod, seed, debug_mask,
                              static_cast<ScheduleErrorRenderLevel>(error_render_level));
    });

/******** (FFI) Lookup random variables ********/

// The dispatchers below switch on the runtime node type of their ObjectRef argument and forward to
// the matching ScheduleNode overload. LOG(FATAL) is not expected to return. The trailing bare
// `throw;` only satisfies the compiler's all-paths-return analysis; if it were ever reached, a bare
// throw with no active exception would call std::terminate.

// Evaluates a LoopRV, BlockRV or ExprRV to its concrete value in this schedule.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGet")
    .set_body_typed([](Schedule self, ObjectRef obj) -> ObjectRef {
      if (const auto* loop_rv = obj.as<LoopRVNode>()) {
        return self->Get(GetRef<LoopRV>(loop_rv));
      }
      if (const auto* block_rv = obj.as<BlockRVNode>()) {
        return self->Get(GetRef<BlockRV>(block_rv));
      }
      if (const auto* expr_rv = obj.as<ExprRVNode>()) {
        return self->Get(GetRef<ExprRV>(expr_rv));
      }
      LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << obj->GetTypeKey()
                 << ". Its value is: " << obj;
      throw;
    });

// Resolves a LoopRV, BlockRV or a raw Stmt to its sref.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetSRef")
    .set_body_typed([](Schedule self, ObjectRef obj) -> Optional<ObjectRef> {
      if (const auto* loop_rv = obj.as<LoopRVNode>()) {
        return self->GetSRef(GetRef<LoopRV>(loop_rv));
      }
      if (const auto* block_rv = obj.as<BlockRVNode>()) {
        return self->GetSRef(GetRef<BlockRV>(block_rv));
      }
      if (const auto* stmt = obj.as<StmtNode>()) {
        return self->GetSRef(GetRef<Stmt>(stmt));
      }
      LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey();
      throw;
    });

// Removes a LoopRV, BlockRV or ExprRV from the schedule's symbol table.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV")
    .set_body_typed([](Schedule self, ObjectRef obj) -> void {
      if (const auto* loop_rv = obj.as<LoopRVNode>()) {
        return self->RemoveRV(GetRef<LoopRV>(loop_rv));
      }
      if (const auto* block_rv = obj.as<BlockRVNode>()) {
        return self->RemoveRV(GetRef<BlockRV>(block_rv));
      }
      if (const auto* expr_rv = obj.as<ExprRVNode>()) {
        return self->RemoveRV(GetRef<ExprRV>(expr_rv));
      }
      LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey();
      throw;
    });

/******** (FFI) Sampling ********/

TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical")
    .set_body_method<Schedule>(&ScheduleNode::SampleCategorical);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSamplePerfectTile")
    .set_body_method<Schedule>(&ScheduleNode::SamplePerfectTile);

/******** (FFI) Get blocks & loops ********/

TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock")
    .set_body_method<Schedule>(&ScheduleNode::GetBlock);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops")
    .set_body_method<Schedule>(&ScheduleNode::GetLoops);

// Child blocks can be queried under either a BlockRV or a LoopRV.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks")
    .set_body_typed([](Schedule self, ObjectRef rv) {
      if (const auto* block_rv = rv.as<BlockRVNode>()) {
        return self->GetChildBlocks(GetRef<BlockRV>(block_rv));
      }
      if (const auto* loop_rv = rv.as<LoopRVNode>()) {
        return self->GetChildBlocks(GetRef<LoopRV>(loop_rv));
      }
      LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey()
                 << ". Its value is: " << rv;
      throw;
    });
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers")
    .set_body_method<Schedule>(&ScheduleNode::GetProducers);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers")
    .set_body_method<Schedule>(&ScheduleNode::GetConsumers);

/******** (FFI) Transform loops ********/

TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&ScheduleNode::Fuse);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder")
    .set_body_method<Schedule>(&ScheduleNode::Reorder);

/******** (FFI) Manipulate ForKind ********/

TVM_REGISTER_GLOBAL("tir.schedule.ScheduleParallel")
    .set_body_method<Schedule>(&ScheduleNode::Parallel);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleVectorize")
    .set_body_method<Schedule>(&ScheduleNode::Vectorize);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBind").set_body_method<Schedule>(&ScheduleNode::Bind);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnroll").set_body_method<Schedule>(&ScheduleNode::Unroll);

/******** (FFI) Insert cache stages ********/

TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead")
    .set_body_method<Schedule>(&ScheduleNode::CacheRead);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite")
    .set_body_method<Schedule>(&ScheduleNode::CacheWrite);

/******** (FFI) Compute location ********/

TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt")
    .set_body_method<Schedule>(&ScheduleNode::ComputeAt);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeAt")
    .set_body_method<Schedule>(&ScheduleNode::ReverseComputeAt);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline")
    .set_body_method<Schedule>(&ScheduleNode::ComputeInline);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline")
    .set_body_method<Schedule>(&ScheduleNode::ReverseComputeInline);

/******** (FFI) Reduction ********/

TVM_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposeReduction")
    .set_body_method<Schedule>(&ScheduleNode::DecomposeReduction);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor")
    .set_body_method<Schedule>(&ScheduleNode::RFactor);

/******** (FFI) Block annotation ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign") .set_body_method(&ScheduleNode::StorageAlign); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope") .set_body_method(&ScheduleNode::SetScope); /******** (FFI) Blockize & Tensorize ********/ /******** (FFI) Annotation ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotate") .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key, const ObjectRef& ann_val) { if (const auto* block_rv = rv.as()) { return self->Annotate(GetRef(block_rv), ann_key, ann_val); } if (const auto* loop_rv = rv.as()) { return self->Annotate(GetRef(loop_rv), ann_key, ann_val); } LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() << ". Its value is: " << rv; throw; }); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate") .set_body_typed([](Schedule self, ObjectRef rv, const String& ann_key) { if (const auto* block_rv = rv.as()) { return self->Unannotate(GetRef(block_rv), ann_key); } if (const auto* loop_rv = rv.as()) { return self->Unannotate(GetRef(loop_rv), ann_key); } LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() << ". Its value is: " << rv; throw; }); /******** (FFI) Misc ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc") .set_body_method(&ScheduleNode::EnterPostproc); } // namespace tir } // namespace tvm