/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file texture_flatten.cc * \brief Flattens texture storage from multi-dimensional array * to 2D (width, height) buffer access */ #include #include #include #include #include #include #include #include "../../arith/ir_visitor_with_analyzer.h" #include "../../runtime/texture.h" #include "../../runtime/thread_storage_scope.h" namespace tvm { namespace tir { using runtime::ApplyTexture2DFlattening; using runtime::DefaultTextureLayoutSeparator; using runtime::IsTextureStorage; class TextureLoweringBase : public StmtExprMutator { public: explicit TextureLoweringBase(const Map& extern_buffer_map, IRVisitorWithAnalyzer* bound_analyzer) : bound_analyzer_{bound_analyzer} { for (auto kv : extern_buffer_map) { extern_buf_.insert(kv.second); } } inline PrimExpr SimplifyOffset(const Array& shape, const Array& index) const { PrimExpr base = make_const(DataType::Int(32), 0); ICHECK_EQ(shape.size(), index.size()); if (index.size() > 0) { PrimExpr offset = index[0]; for (size_t i = 1; i < index.size(); ++i) { offset = bound_analyzer_->Simplify(offset * shape[i] + index[i]); } base = base + offset; } return base; } protected: std::string GetStorageScope(const Buffer& buffer) { auto* ptr = buffer->data->type_annotation.as(); ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType"; return ptr->storage_scope; } // Set of all external input and output buffers std::unordered_set extern_buf_; // Bound analzer IRVisitorWithAnalyzer* bound_analyzer_; }; // Lower Nd storage access to 2d texture access using lowering convention // specified by the buffers storage scope. class TextureFlattener : public TextureLoweringBase { public: using StmtExprMutator::VisitStmt_; explicit TextureFlattener(const Map& extern_buffer_map, IRVisitorWithAnalyzer* bound_analyzer) : TextureLoweringBase(extern_buffer_map, bound_analyzer) {} Stmt VisitStmt_(const BufferRealizeNode* op) final { if (extern_buf_.count(op->buffer)) { return this->VisitStmt(op->body); } std::string storage_scope = GetStorageScope(op->buffer); Var buffer_var(op->buffer->data->name_hint, PointerType(PrimType(op->buffer->dtype), String(storage_scope))); let_binding_.insert({op->buffer->data, buffer_var}); Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); // Rewrite any buffer realizations with storage scope to 2d texture allocations if (IsTextureStorage(storage_scope)) { Stmt body = this->VisitStmt(op->body); ICHECK(op->bounds.size() >= 3) << "Only 2d RGBA texture is currently supported"; int vec_length = static_cast(op->bounds.back()->extent.as()->value); ICHECK(vec_length == 4 || vec_length == 1) << "Inner dimension of texture must be vector of length 1 or 4 (RGBA)"; struct ShapeFromRange { const Array& bounds; PrimExpr operator[](size_t i) const { return bounds[i]->extent; } }; size_t axis = DefaultTextureLayoutSeparator(op->bounds.size(), storage_scope); auto texture = ApplyTexture2DFlattening(ShapeFromRange{op->bounds}, op->bounds.size(), axis); Array args = {texture.width, texture.height}; stmt = LetStmt(buffer_var, Call(buffer_var.dtype(), builtin::texture2d_alloca(), args), body); } return stmt; } Stmt VisitStmt_(const BufferStoreNode* op) final { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); std::string storage_scope = GetStorageScope(op->buffer); // Lower to two dimensional access if (IsTextureStorage(storage_scope)) { Array args = GetTextureAccessArgs(op, op->buffer); args.push_back(op->value); stmt = Evaluate(Call(args[0]->dtype, builtin::texture2d_store(), args)); } return stmt; } PrimExpr VisitExpr_(const BufferLoadNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); // Lower to two dimensional access std::string storage_scope = GetStorageScope(op->buffer); if (IsTextureStorage(storage_scope)) { Array args = GetTextureAccessArgs(op, op->buffer); args.push_back(op->indices.back()); expr = Call(op->buffer->dtype, builtin::texture2d_load(), args); } return expr; } protected: template Array GetTextureAccessArgs(const T* op, const Buffer& buffer) { Array args; if (let_binding_.count(op->buffer->data)) { args.push_back(let_binding_[op->buffer->data]); } else { args.push_back(buffer->data); } Array row_dims, row_indices, col_dims, col_indices; for (size_t i = 0; i < op->buffer->shape.size() - 1; i++) { if (i < DefaultTextureLayoutSeparator(op->buffer->shape.size(), GetStorageScope(buffer))) { col_dims.push_back(op->buffer->shape[i]); col_indices.push_back(op->indices[i]); } else { row_dims.push_back(op->buffer->shape[i]); row_indices.push_back(op->indices[i]); } } PrimExpr row_offset = SimplifyOffset(row_dims, row_indices); PrimExpr col_offset = SimplifyOffset(col_dims, col_indices); args.push_back(row_offset); args.push_back(col_offset); return args; } // Bindings to new texture vars with texture pointer scope std::unordered_map let_binding_; }; PrimFunc TextureFlatten(PrimFunc func) { auto fptr = func.CopyOnWrite(); IRVisitorWithAnalyzer bound_analyzer; bound_analyzer(fptr->body); fptr->body = TextureFlattener(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body)); return func; } namespace transform { Pass TextureFlatten() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { return TextureFlatten(std::move(f)); }; return CreatePrimFuncPass(pass_func, 0, "tir.TextureFlatten", {}); } TVM_REGISTER_GLOBAL("tir.transform.TextureFlatten").set_body_typed(TextureFlatten); } // namespace transform } // namespace tir } // namespace tvm