# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from __future__ import print_function import sys import os.path, re, StringIO blacklist = [ "Windows.h", "mach/clock.h", "mach/mach.h", "malloc.h", "glog/logging.h", "io/azure_filesys.h", "io/hdfs_filesys.h", "io/s3_filesys.h", "sys/stat.h", "sys/types.h", "omp.h", "execinfo.h", "packet/sse-inl.h", ] def get_sources(def_file): sources = [] files = [] visited = set() mxnet_path = os.path.abspath( os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir) ) for line in open(def_file): files = files + line.strip().split(" ") for f in files: f = f.strip() if not f or f.endswith(".o:") or f == "\\": continue fn = os.path.relpath(f) if os.path.abspath(f).startswith(mxnet_path) and fn not in visited: sources.append(fn) visited.add(fn) return sources sources = get_sources(sys.argv[1]) def find_source(name, start): candidates = [] for x in sources: if x == name or x.endswith("/" + name): candidates.append(x) if not candidates: return "" if len(candidates) == 1: return candidates[0] for x in candidates: if x.split("/")[1] == start.split("/")[1]: return x return "" re1 = re.compile("<([./a-zA-Z0-9_-]*)>") re2 = re.compile('"([./a-zA-Z0-9_-]*)"') sysheaders = [] history = set([]) out = StringIO.StringIO() def expand(x, pending): if x in history and x not in ["mshadow/mshadow/expr_scalar-inl.h"]: # MULTIPLE includes return if x in pending: # print('loop found: %s in ' % x, pending) return print("//===== EXPANDING: %s =====\n" % x, file=out) for line in open(x): if line.find("#include") < 0: out.write(line) continue if line.strip().find("#include") > 0: print(line) continue m = re1.search(line) if not m: m = re2.search(line) if not m: print(line + " not found") continue h = m.groups()[0].strip("./") source = find_source(h, x) if not source: if h not in blacklist and h not in sysheaders and "mkl" not in h and "nnpack" not in h: sysheaders.append(h) else: expand(source, pending + [x]) print("//===== EXPANDED: %s =====\n" % x, file=out) history.add(x) expand(sys.argv[2], []) f = open(sys.argv[3], "wb") for k in sorted(sysheaders): print("#include <%s>" % k, file=f) print("", file=f) print(out.getvalue(), file=f) for x in sources: if x not in history and not x.endswith(".o"): print("Not processed:", x)