/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "stripe_config.h"

#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <map>
#include <utility>
#include <vector>

#include "common.h"

namespace tvm {
namespace contrib {
namespace ethosu {
namespace cascader {

/*!
 * \brief Compute the Cartesian product of several per-axis histograms.
 *
 * Each element of \p values maps a value to the number of times it occurs along one axis.
 * The result maps every combination (one value per axis, in axis order) to the product of
 * the per-axis counts.
 *
 * \param values One histogram per axis.
 * \return Map from value combination to its multiplicity. With no axes the result holds a
 *         single empty combination with multiplicity 1 (the identity of the product); if any
 *         axis histogram is empty the result is empty.
 */
template <typename T>
std::map<std::vector<T>, int> MultiplyCombinations(const std::vector<std::map<T, int>>& values) {
  // Start from the identity of the product: one empty combination, counted once. Folding
  // axis by axis avoids recursing on (and copying) ever-shorter prefixes of `values`, and
  // makes the zero-axis case well defined.
  std::map<std::vector<T>, int> combs;
  combs[std::vector<T>()] = 1;
  for (const auto& axis_counts : values) {
    std::map<std::vector<T>, int> new_combs;
    for (const auto& val_it : axis_counts) {
      for (const auto& comb_it : combs) {
        std::vector<T> new_comb(comb_it.first);
        new_comb.push_back(val_it.first);
        new_combs[std::move(new_comb)] = val_it.second * comb_it.second;
      }
    }
    combs = std::move(new_combs);
  }
  return combs;
}

/*!
 * \brief Count how many stripes of each size a StripeConfig produces.
 *
 * For every axis, the stripes are clipped to [0, extent) and a histogram of their sizes is
 * built (stripes of non-positive size are ignored). The per-axis histograms are then combined
 * with MultiplyCombinations, so each key is a vector of per-axis sizes and each value is the
 * number of stripes with exactly that size.
 *
 * \param stripe_config The StripeConfig to analyse.
 * \param enable_sliding_window If true, each stripe after the first along an axis only counts
 *        the part that does not overlap the previous stripe on that axis.
 * \return Map from per-axis stripe size to number of stripes of that size.
 */
std::map<std::vector<int>, int> CountStripes(const StripeConfig& stripe_config,
                                             bool enable_sliding_window = false) {
  const size_t num_axes = stripe_config->GetOrder().size();
  std::vector<std::map<int, int>> per_axis_sizes(num_axes);
  for (size_t axis = 0; axis < num_axes; axis++) {
    int start = stripe_config->GetOffset()[axis];
    size_t stripe_count = static_cast<size_t>(stripe_config->GetStripes()[axis]);
    int stride = stripe_config->GetStrides()[axis];
    int shape = stripe_config->GetShape()[axis];
    int extent = stripe_config->GetExtent()[axis];
    // End of the previous stripe along this axis; only consulted in sliding-window mode.
    int high = std::numeric_limits<int>::min();
    for (size_t i = 0; i < stripe_count; i++) {
      // Fast path: stripes lying entirely inside [0, extent) all have the same size, so a
      // whole run of them is counted in one go. In sliding-window mode the first stripe is
      // always handled by the general path because it has no predecessor to overlap with.
      if ((!enable_sliding_window || i > 0) && start >= 0 && extent - shape - start >= 0 &&
          stride > 0) {
        int whole_stripes =
            std::min(static_cast<int>(stripe_count - i), (extent - shape - start) / stride + 1);
        // Must agree with the general path below: a full stripe contributes `shape`
        // elements or, in sliding-window mode, only the part not overlapping the previous
        // stripe, which is min(stride, shape) (not `stride` when stride > shape).
        int size = enable_sliding_window ? std::min(stride, shape) : shape;
        if (size > 0) {
          per_axis_sizes[axis][size] += whole_stripes;
        }
        i += whole_stripes - 1;
        start += whole_stripes * stride;
        high = std::min(start - stride + shape, extent);
        continue;
      }
      // General path: clip the stripe to [0, extent) and, in sliding-window mode, drop the
      // part already covered by the previous stripe.
      int low = std::max(start, 0);
      if (enable_sliding_window) {
        low = std::max(low, high);
      }
      high = std::min(start + shape, extent);
      int size = high - low;
      if (size > 0) {
        per_axis_sizes[axis][size]++;
      }
      start += stride;
    }
  }
  return MultiplyCombinations(per_axis_sizes);
}

// FFI wrapper: returns the CountStripes result with the std::vector keys converted to Arrays.
TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.CountStripes")
    .set_body_typed([](StripeConfig stripe_config, bool enable_sliding_window) {
      Map<Array<Integer>, Integer> ret;
      auto stripe_counts = CountStripes(stripe_config, enable_sliding_window);
      for (const auto& it : stripe_counts) {
        ret.Set(make_array(it.first), it.second);
      }
      return ret;
    });

/*!
 * \brief Expose the StripeConfig fields to an AttrVisitor.
 *
 * Each std::vector member is copied into a temporary TVM Array before being visited, and
 * the cached hash is visited as an int64_t copy, so the visitor sees copies of the fields.
 */
void StripeConfigNode::VisitAttrs(AttrVisitor* v) {
  Array<Integer> tmp_arr = make_array(shape_);
  v->Visit("_shape", &tmp_arr);
  tmp_arr = make_array(extent_);
  v->Visit("_extent", &tmp_arr);
  tmp_arr = make_array(order_);
  v->Visit("_order", &tmp_arr);
  tmp_arr = make_array(stripes_);
  v->Visit("_stripes", &tmp_arr);
  tmp_arr = make_array(offset_);
  v->Visit("_offset", &tmp_arr);
  Array<FloatImm> tmp_float_arr = make_array(strides_);
  v->Visit("_strides", &tmp_float_arr);
  int64_t tmp_hash = static_cast<int64_t>(hash_);
  v->Visit("_hash", &tmp_hash);
}

/*! \brief Recompute the cached hash_ from every field that operator== compares. */
void StripeConfigNode::ComputeHash_() {
  hash_ = hash_vector(shape_);
  hash_combine(&hash_, hash_vector(extent_));
  hash_combine(&hash_, hash_vector(strides_));
  hash_combine(&hash_, hash_vector(order_));
  hash_combine(&hash_, hash_vector(stripes_));
  hash_combine(&hash_, hash_vector(offset_));
}

/*!
 * \brief Construct a StripeConfig from its per-axis parameters.
 *
 * The vectors are copied into a new StripeConfigNode and its hash is computed immediately.
 * (The parameters are const references, so they are copied; std::move would be a no-op.)
 */
StripeConfig::StripeConfig(const std::vector<int>& shape, const std::vector<int>& extent,
                           const std::vector<float>& strides, const std::vector<int>& order,
                           const std::vector<int>& stripes, const std::vector<int>& offset) {
  auto n = make_object<StripeConfigNode>();
  n->shape_ = shape;
  n->extent_ = extent;
  n->strides_ = strides;
  n->order_ = order;
  n->stripes_ = stripes;
  n->offset_ = offset;
  n->ComputeHash_();
  data_ = std::move(n);
}

/*!
 * \brief Structural equality of two StripeConfigs.
 *
 * Two configs are equal if they refer to the same node, or if both are non-null and every
 * field is equal. A null config is only equal to another null config.
 *
 * Not declared `inline`: this definition lives only in this translation unit but is reachable
 * from others (via the header and the FFI registration below).
 */
bool StripeConfig::operator==(const StripeConfig& other) const {
  if (get() == other.get()) return true;
  if (get() == nullptr || other.get() == nullptr) return false;
  return ((*this)->shape_ == other->shape_ && (*this)->extent_ == other->extent_ &&
          (*this)->strides_ == other->strides_ && (*this)->order_ == other->order_ &&
          (*this)->stripes_ == other->stripes_ && (*this)->offset_ == other->offset_);
}

// FFI constructor: converts the TVM Arrays to std::vectors and builds a StripeConfig.
TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.StripeConfig")
    .set_body_typed([](Array<Integer> shape, Array<Integer> extent, Array<FloatImm> strides,
                       Array<Integer> order, Array<Integer> stripes, Array<Integer> offset) {
      std::vector<int> vshape = make_vector<int, Integer>(shape);
      std::vector<int> vextent = make_vector<int, Integer>(extent);
      std::vector<float> vstrides = make_vector<float, FloatImm>(strides);
      std::vector<int> vorder = make_vector<int, Integer>(order);
      std::vector<int> vstripes = make_vector<int, Integer>(stripes);
      std::vector<int> voffset = make_vector<int, Integer>(offset);
      return StripeConfig(vshape, vextent, vstrides, vorder, vstripes, voffset);
    });

TVM_REGISTER_GLOBAL("contrib.ethosu.cascader.StripeConfigEqual")
    .set_body_method(&StripeConfig::operator==);

TVM_REGISTER_NODE_TYPE(StripeConfigNode);

}  // namespace cascader
}  // namespace ethosu
}  // namespace contrib
}  // namespace tvm