/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ #include #include #include namespace tvm { namespace relay { namespace transform { namespace { /*! \brief Collect all attributes whose name contains "layout". */ struct CollectAttrs : public AttrVisitor { void Visit(const char* key, std::string* value) final { if (std::string(key).find("layout") != std::string::npos) { attrs[key] = String(*value); } } void Visit(const char* key, double* value) final {} void Visit(const char* key, uint64_t* value) final {} void Visit(const char* key, int* value) final {} void Visit(const char* key, int64_t* value) final {} void Visit(const char* key, bool* value) final {} void Visit(const char* key, runtime::NDArray* value) final {} void Visit(const char* key, ObjectRef* value) final { if (std::string(key).find("layout") != std::string::npos) { attrs[key] = *value; } } void Visit(const char* key, DataType* value) final {} void Visit(const char* key, void** value) final {} std::unordered_map attrs; }; } // namespace /*! \brief Visitor to add structural hash and layout information to `Function` * nodes. Sets the "hash" field on the attr to the structural hash of the * function. Propogates any attributes with "layout" in their name from call * nodes in the Function to the Function's attrs. */ class LabelOpsMutator : public MixedModeMutator { private: using MixedModeMutator::VisitExpr_; std::unordered_map body_attrs; Expr VisitExpr_(const FunctionNode* op) final { if (op->GetAttr("hash").defined()) { // Already labelled. return ExprMutator::VisitExpr_(op); } // body_attrs collects attrs from Calls in the body of this Function. Reset // it so we only get attrs from this Function. body_attrs = {}; auto updated = ExprMutator::VisitExpr_(op); size_t hash = StructuralHash()(updated); // format hash as fixed length hex string so it is easier to read std::stringstream s; s << std::setfill('0') << std::setw(sizeof(size_t) * 2) << std::hex << hash; Function f = WithAttr(Downcast(updated), "hash", String(s.str())); for (auto p : body_attrs) { f = WithAttr(f, p.first, p.second); } return std::move(f); } Expr VisitExpr_(const LetNode* op) final { auto pre_visit = [this](const LetNode* op) { this->Mutate(op->var); this->Mutate(op->value); }; auto post_visit = [this](const LetNode* op) { Var var = Downcast(this->Mutate(op->var)); auto value = this->Mutate(op->value); auto body = this->Mutate(op->body); auto expr = GetRef(op); if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { this->memo_[expr] = expr; } else { this->memo_[expr] = Let(var, value, body); } }; ExpandANormalForm(op, pre_visit, post_visit); return memo_[GetRef(op)]; } Expr Rewrite_(const CallNode* op, const Expr& post) final { auto updated = MixedModeMutator::Rewrite_(op, post); if (op->attrs.defined()) { CollectAttrs collect; const_cast(op->attrs.get())->VisitAttrs(&collect); for (auto p : collect.attrs) { if (body_attrs.find(p.first) != body_attrs.end() && p.second == body_attrs[p.first]) { LOG(WARNING) << "LabelOps found two call sites with different values for " << p.first << " (" << p.second << " vs " << body_attrs[p.first] << "). Only the first will be recorded."; } body_attrs[p.first] = p.second; } } return updated; } }; /*! \brief Add structural hash and layout information to Function nodes. This * information is used later by the profiler. * * The hash and layout information is added to the attrs field of the Function. * The key "hash" contains the structural hash of the node. Any attributes with * "layout" in their name are also added to attrs (for example, * `attrs["src_layout"]` contains the `src_layout` attribute of the TVM op * corresponding to this function call). */ Pass LabelOps() { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { return Downcast(LabelOpsMutator().Mutate(f)); }; return CreateFunctionPass(pass_func, 1, "LabelOps", {}); } } // namespace transform } // namespace relay } // namespace tvm