/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace tir {

/*!
 * \brief Look up the annotation map of the loop or block that `sref` points to.
 * \param sref The sref of a For loop or a Block.
 * \return A reference to the `annotations` field of the underlying ForNode / BlockNode.
 * \note Reports a fatal "TypeError" when the statement is neither a For nor a Block.
 */
static const Map<String, ObjectRef>& GetLoopOrBlockAnnotations(const StmtSRef& sref) {
  if (const auto* loop = sref->StmtAs<ForNode>()) {
    return loop->annotations;
  }
  if (const auto* block = sref->StmtAs<BlockNode>()) {
    return block->annotations;
  }
  LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey();
  // Presumably unreachable (LOG(FATAL) does not return); the bare rethrow only keeps the
  // compiler from warning about a missing return value.
  throw;
}

/*!
 * \brief Replace the loop or block that `sref` points to with a copy whose annotations are
 *        `new_ann`, recording the replacement in the schedule state.
 * \param self The schedule state to update.
 * \param sref The sref of a For loop or a Block.
 * \param new_ann The complete annotation map the new statement should carry.
 * \note Reports a fatal "TypeError" when the statement is neither a For nor a Block.
 */
static void ReplaceLoopOrBlockAnnotations(ScheduleState self, const StmtSRef& sref,
                                          Map<String, ObjectRef> new_ann) {
  if (const auto* loop = sref->StmtAs<ForNode>()) {
    ObjectPtr<ForNode> n = make_object<ForNode>(*loop);
    n->annotations = std::move(new_ann);
    self->Replace(sref, For(n), {});
  } else if (const auto* block = sref->StmtAs<BlockNode>()) {
    ObjectPtr<BlockNode> n = make_object<BlockNode>(*block);
    n->annotations = std::move(new_ann);
    Block p(n);
    // Map the old block to its annotated copy so Replace can carry the block over to the
    // new statement (block reuse map).
    self->Replace(sref, p, {{GetRef<Block>(block), p}});
  } else {
    LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey();
    throw;
  }
}

/*!
 * \brief Attach the annotation `ann_key -> ann_val` to a loop or a block.
 * \param self The schedule state to update.
 * \param sref The sref of a For loop or a Block.
 * \param ann_key The annotation key.
 * \param ann_val The annotation value.
 * \note If `ann_key` is already present, the function returns without doing anything: the
 *       existing value is kept and no replacement is recorded.
 *       NOTE(review): callers expecting overwrite semantics should be aware of this.
 */
void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key,
              const ObjectRef& ann_val) {
  const Map<String, ObjectRef>& annotations = GetLoopOrBlockAnnotations(sref);
  if (annotations.find(ann_key) != annotations.end()) {
    return;
  }
  // Copy first; the statement node itself is never mutated in place.
  Map<String, ObjectRef> new_ann(annotations);
  new_ann.Set(ann_key, ann_val);
  ReplaceLoopOrBlockAnnotations(self, sref, std::move(new_ann));
}

/*!
 * \brief Remove the annotation `ann_key` from a loop or a block.
 * \param self The schedule state to update.
 * \param sref The sref of a For loop or a Block.
 * \param ann_key The annotation key to remove.
 * \note Fails an ICHECK with "IndexError" when `ann_key` is not present.
 */
void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key) {
  const Map<String, ObjectRef>& annotations = GetLoopOrBlockAnnotations(sref);
  ICHECK(annotations.find(ann_key) != annotations.end())
      << "IndexError: Cannot find annotation key: " << ann_key;
  Map<String, ObjectRef> new_ann(annotations);
  new_ann.erase(ann_key);
  ReplaceLoopOrBlockAnnotations(self, sref, std::move(new_ann));
}

/*!
 * \brief Instruction traits for `Annotate`.
 *        Inputs: the block/loop random variable and the annotation value.
 *        Attribute: the annotation key. The instruction is not pure: it modifies the schedule.
 */
struct AnnotateTraits : public UnpackedInstTraits<AnnotateTraits> {
  static constexpr const char* kName = "Annotate";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 2;
  static constexpr size_t kNumAttrs = 1;
  static constexpr size_t kNumDecisions = 0;

  /*! \brief Dispatch to Schedule::Annotate on either a BlockRV or a LoopRV. */
  static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, ObjectRef ann_val,
                                      String ann_key) {
    if (const auto* block = block_or_loop_rv.as<BlockRVNode>()) {
      return sch->Annotate(GetRef<BlockRV>(block), ann_key, ann_val);
    }
    if (const auto* loop = block_or_loop_rv.as<LoopRVNode>()) {
      return sch->Annotate(GetRef<LoopRV>(loop), ann_key, ann_val);
    }
    LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey();
    throw;
  }

  /*!
   * \brief Render the instruction as a Python `annotate(...)` call.
   *        Supported value kinds: integer immediates, strings, and other PrimExprs; anything
   *        else is a fatal "TypeError".
   */
  static String UnpackedAsPython(Array<String> outputs, ObjectRef block_or_loop_rv,
                                 ObjectRef ann_val, String ann_key) {
    PythonAPICall py("annotate");
    py.Input("block_or_loop", block_or_loop_rv);
    py.Input("ann_key", ann_key);
    // IntImm is tested before the generic PrimExpr branch so integers are emitted as a plain
    // decimal literal.
    if (const auto* int_imm = ann_val.as<IntImmNode>()) {
      py.Input("ann_val", std::to_string(int_imm->value));
    } else if (const auto* str_imm = ann_val.as<StringObj>()) {
      py.Input("ann_val", GetRef<String>(str_imm));
    } else if (const auto* expr = ann_val.as<PrimExprNode>()) {
      std::ostringstream os;
      os << GetRef<PrimExpr>(expr);
      py.Input("ann_val", os.str());
    } else {
      LOG(FATAL) << "TypeError: Cannot handle type: " << ann_val->GetTypeKey();
      throw;
    }
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

/*!
 * \brief Instruction traits for `Unannotate`.
 *        Input: the block/loop random variable. Attribute: the annotation key.
 *        The instruction is not pure: it modifies the schedule.
 */
struct UnannotateTraits : public UnpackedInstTraits<UnannotateTraits> {
  static constexpr const char* kName = "Unannotate";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 1;
  static constexpr size_t kNumDecisions = 0;

  /*! \brief Dispatch to Schedule::Unannotate on either a BlockRV or a LoopRV. */
  static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String ann_key) {
    if (const auto* block = block_or_loop_rv.as<BlockRVNode>()) {
      return sch->Unannotate(GetRef<BlockRV>(block), ann_key);
    }
    if (const auto* loop = block_or_loop_rv.as<LoopRVNode>()) {
      return sch->Unannotate(GetRef<LoopRV>(loop), ann_key);
    }
    LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey();
    throw;
  }

  /*! \brief Render the instruction as a Python `unannotate(...)` call. */
  static String UnpackedAsPython(Array<String> outputs, ObjectRef block_or_loop_rv,
                                 String ann_key) {
    PythonAPICall py("unannotate");
    py.Input("block_or_loop", block_or_loop_rv);
    py.Input("ann_key", ann_key);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

TVM_REGISTER_INST_KIND_TRAITS(AnnotateTraits);
TVM_REGISTER_INST_KIND_TRAITS(UnannotateTraits);

}  // namespace tir
}  // namespace tvm