/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file verify_memory.cc * \brief Pass to check if memory accesses are legal. */ #include #include #include #include #include #include #include namespace tvm { namespace tir { namespace { /*! * \brief Verify if memory accesses are legal. * * In the case that tgt is cuda, if workload is not bound with * threads, CPU code is generated that tries to access GPU memory, * which is illegal. * * This pass performs such verification by checking if all * memory accesses are bound with threads when device type is GPU. */ class MemoryAccessVerifier final : protected StmtExprVisitor { public: /// Special member functions //@{ explicit MemoryAccessVerifier(PrimFunc f, int device_type) : func_(f), dev_type_(device_type) {} virtual ~MemoryAccessVerifier() = default; MemoryAccessVerifier(const MemoryAccessVerifier&) = delete; MemoryAccessVerifier(MemoryAccessVerifier&&) = delete; MemoryAccessVerifier& operator=(const MemoryAccessVerifier&) = delete; MemoryAccessVerifier& operator=(MemoryAccessVerifier&&) = delete; //@} /// Interface to perform memory access verification void Run() { if (!IsGPUDevice(dev_type_) && !IsFPGADevice(dev_type_)) return; StmtExprVisitor::VisitStmt(func_->body); } /// Verification result std::vector Errors() const { return errs_; } protected: /// Visitor implementation //@{ void VisitExpr(const PrimExpr& n) final { StmtExprVisitor::VisitExpr(n); } void VisitStmt(const Stmt& n) final { StmtExprVisitor::VisitStmt(n); } void VisitStmt_(const LetStmtNode* op) final { // Book keep definitions defs_[op->var.get()] = op->value; return StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const AttrStmtNode* op) final { if (!InThreadEnv() && (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope)) { EnterThreadEnv(); StmtExprVisitor::VisitStmt_(op); ExitThreadEnv(); } else { StmtExprVisitor::VisitStmt_(op); } } void VisitExpr_(const LoadNode* op) final { HandleLoadStoreToVariable(op->buffer_var); return StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const StoreNode* op) final { HandleLoadStoreToVariable(op->buffer_var); return StmtExprVisitor::VisitStmt_(op); } //@} /// Check if the value of a Variable comes from function argument. bool IsFromFunctionArgs(const VarNode* var) const { const VarNode* V = var; for (auto kv : func_->buffer_map) { if (V == kv.second->data.get()) return true; } while (true) { // Variable is from function args. Return true. if (V == func_->params[0].get()) return true; // The value is expected to come from a tvm_struct_get Call. // Get the first argument of tvm_struct_get, and continue. const auto& iter = defs_.find(V); if (iter == defs_.end()) return false; const CallNode* C = iter->second.as(); if (!C || !C->op.same_as(builtin::tvm_struct_get())) return false; V = C->args[0].as(); } return false; } /// Handle memory access to a Variable void HandleLoadStoreToVariable(const Var& var) { // We skip the access within thread env. if (InThreadEnv()) return; // We only handle the variable from function argument. // If it does not come from args, then it could be allocated internally, // it may possibly be in host or device address space. // We do not handle this case, and skip it conservatively. if (!IsFromFunctionArgs(var.get())) return; // The verification fails in this case. std::stringstream s; s << "Variable `" << var << "` is directly accessed by host memory (it is not contained in a thread environment or in " "the function arguments."; errs_.push_back(s.str()); } /// Status getter/setter //@{ bool InThreadEnv() const { return in_thread_env_; } void EnterThreadEnv() { in_thread_env_ = true; } void ExitThreadEnv() { in_thread_env_ = false; } //@} /// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device. static bool IsGPUDevice(int dev_type) { return kDLCUDA == dev_type || kDLOpenCL == dev_type || kDLVulkan == dev_type || kDLMetal == dev_type || kDLROCM == dev_type || kOpenGL == dev_type; } /// Check if a given DLDeviceType/TVMDeviceExtType value denotes FPGA device. static bool IsFPGADevice(int dev_type) { return kDLSDAccel == dev_type || kDLAOCL == dev_type; } private: /// Status of visitor //@{ bool in_thread_env_{false}; std::vector errs_; //@} tir::PrimFunc func_{nullptr}; ///< Function to be verified. int dev_type_{kDLCPU}; ///< Device type std::unordered_map defs_; ///< Variable definitions }; } // namespace /// Interface of VerifyMemory pass std::vector VerifyMemory_(const PrimFunc& func) { auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "VerifyMemory: Require the target attribute"; VLOG(1) << "verifying memory for target '" << target.value()->str() << "' for primitive:" << std::endl << PrettyPrint(func); if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDefault) { MemoryAccessVerifier v(func, target.value()->kind->device_type); v.Run(); return v.Errors(); } else { return {}; } } bool VerifyMemory(const PrimFunc& func) { return VerifyMemory_(func).size() == 0; } TVM_REGISTER_GLOBAL("tir.analysis.verify_memory").set_body_typed(VerifyMemory); namespace transform { Pass VerifyMemory() { auto pass_func = [=](IRModule mod, PassContext ctx) { for (auto kv : mod->functions) { if (auto* n = kv.second.as()) { auto func = GetRef(n); auto errs = VerifyMemory_(func); if (errs.size() > 0) { std::stringstream s; for (auto& err : errs) { s << " " << err << "\n"; } LOG(FATAL) << "RuntimeError: Memory verification failed with the following errors:\n" << s.str() << " Did you forget to bind?\n" << func; } } } return mod; }; return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory", {}); } TVM_REGISTER_GLOBAL("tir.transform.VerifyMemory").set_body_typed(VerifyMemory); } // namespace transform } // namespace tir } // namespace tvm