Skip to content

Commit

Permalink
Merge pull request #1648 from freechipsproject/heap-bound
Browse files Browse the repository at this point in the history
  • Loading branch information
jackkoenig authored Jun 10, 2020
2 parents 1769f8d + f7ae514 commit a7fe69b
Show file tree
Hide file tree
Showing 5 changed files with 322 additions and 71 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ spec/spec.out
spec/spec.synctex.gz
notes/*.docx
test_run_dir
__pycache__

.project

Expand Down
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,4 @@ jobs:
use: firrtl_build
script:
- benchmark/scripts/benchmark_cold_compile.py -N 2 --designs regress/ICache.fir --versions HEAD
- benchmark/scripts/find_heap_bound.py -- -cp firrtl*jar firrtl.stage.FirrtlMain -i regress/ICache.fir -o out -X verilog
80 changes: 9 additions & 71 deletions benchmark/scripts/benchmark_cold_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
# See LICENSE for license details.

import subprocess
import re
from statistics import median, stdev
import sys
import argparse
from collections import OrderedDict
import os
import numbers

from monitor_job import monitor_job

# Currently hardcoded
def get_firrtl_repo():
cmd = ['git', 'rev-parse', '--show-toplevel']
Expand All @@ -19,65 +19,13 @@ def get_firrtl_repo():

firrtl_repo = get_firrtl_repo()

platform = ""
if sys.platform == 'darwin':
print("Running on MacOS")
platform = 'macos'
elif sys.platform.startswith("linux"):
print("Running on Linux")
platform = 'linux'
else :
raise Exception('Unrecognized platform ' + sys.platform)

def time():
if platform == 'macos':
return ['/usr/bin/time', '-l']
if platform == 'linux':
return ['/usr/bin/time', '-v']

def extract_max_size(output):
regex = ''
if platform == 'macos':
regex = '(\d+)\s+maximum resident set size'
if platform == 'linux':
regex = 'Maximum resident set size[^:]*:\s+(\d+)'

m = re.search(regex, output, re.MULTILINE)
if m :
return int(m.group(1))
else :
raise Exception('Max set size not found!')

def extract_run_time(output):
regex = ''
res = None
if platform == 'macos':
regex = '(\d+\.\d+)\s+real'
if platform == 'linux':
regex = 'Elapsed \(wall clock\) time \(h:mm:ss or m:ss\): ([0-9:.]+)'
m = re.search(regex, output, re.MULTILINE)
if m :
text = m.group(1)
if platform == 'macos':
return float(text)
if platform == 'linux':
parts = text.split(':')
if len(parts) == 3:
return float(parts[0]) * 3600 + float(parts[1]) * 60 + float(parts[0])
if len(parts) == 2:
return float(parts[0]) * 60 + float(parts[1])
raise Exception('Runtime not found!')

def run_firrtl(java, jar, design):
java_cmd = java.split()
cmd = time() + java_cmd + ['-cp', jar, 'firrtl.stage.FirrtlMain', '-i', design,'-o','out.v','-X','verilog']
result = subprocess.run(cmd, stderr=subprocess.PIPE, stdout=subprocess.PIPE)
if result.returncode != 0 :
print(result.stdout)
print(result.stderr)
sys.exit(1)
size = extract_max_size(result.stderr.decode('utf-8'))
runtime = extract_run_time(result.stderr.decode('utf-8'))
cmd = java_cmd + ['-cp', jar, 'firrtl.stage.FirrtlMain', '-i', design,'-o','out.v','-X','verilog']
print(' '.join(cmd))
resource_use = monitor_job(cmd)
size = resource_use.maxrss // 1024 # KiB -> MiB
runtime = resource_use.wall_clock_time
return (size, runtime)

def parseargs():
Expand Down Expand Up @@ -138,15 +86,6 @@ def check_designs(designs):
for design in designs:
assert os.path.exists(design), '{} must be an existing file!'.format(design)

# /usr/bin/time -v on Linux returns size in kbytes
# /usr/bin/time -l on MacOS returns size in Bytes
def norm_max_set_sizes(sizes):
div = None
if platform == 'linux':
d = 1000.0
if platform == 'macos':
d = 1000000.0
return [s / d for s in sizes]

def main():
args = parseargs()
Expand All @@ -156,7 +95,7 @@ def main():
jars = build_firrtl_jars(hashes)
jvms = args.jvms
N = args.iterations
info = [['java', 'revision', 'design', 'max heap', 'SD', 'runtime', 'SD']]
info = [['java', 'revision', 'design', 'max heap (MiB)', 'SD', 'runtime (s)', 'SD']]
for java in jvms:
print("Running with '{}'".format(java))
for hashcode, jar in jars.items():
Expand All @@ -166,8 +105,7 @@ def main():
for design in designs:
print('Running {}...'.format(design))
(sizes, runtimes) = zip(*[run_firrtl(java, jar, design) for i in range(N)])
norm_sizes = norm_max_set_sizes(sizes)
info.append([java_title, revision, design, median(norm_sizes), stdev(norm_sizes), median(runtimes), stdev(runtimes)])
info.append([java_title, revision, design, median(sizes), stdev(sizes), median(runtimes), stdev(runtimes)])
java_title = ''
revision = ''

Expand Down
188 changes: 188 additions & 0 deletions benchmark/scripts/find_heap_bound.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
#!/usr/bin/env python3
# See LICENSE for license details.

import re
import argparse
from typing import NamedTuple
from subprocess import TimeoutExpired
import logging

from monitor_job import monitor_job, JobFailedError

BaseHeapSize = NamedTuple('JavaHeapSize', [('value', int), ('suffix', str)])
class HeapSize(BaseHeapSize):
K_FACTOR = 1024
M_FACTOR = 1024*1024
G_FACTOR = 1024*1024*1024

def toBytes(self) -> int:
return {
"": 1,
"K": self.K_FACTOR,
"M": self.M_FACTOR,
"G": self.G_FACTOR
}[self.suffix] * self.value

def round_to(self, target):
"""Round to positive multiple of target, only if we have the same suffix"""
if self.suffix == target.suffix:
me = self.toBytes()
tar = target.toBytes()
res = tar * round(me / tar)
if res == 0:
res = tar
return HeapSize.from_bytes(res)
else:
return self

def __truediv__(self, div):
b = self.toBytes()
res = int(b / div)
return HeapSize.from_bytes(res)

def __mul__(self, m):
b = self.toBytes()
res = int(b * m)
return HeapSize.from_bytes(res)

def __add__(self, rhs):
return HeapSize.from_bytes(self.toBytes() + rhs.toBytes())

def __sub__(self, rhs):
return HeapSize.from_bytes(self.toBytes() - rhs.toBytes())

@classmethod
def from_str(cls, s: str):
regex = '(\d+)([kKmMgG])?'
m = re.match(regex, s)
if m:
suffix = m.group(2)
if suffix is None:
return HeapSize(int(m.group(1)), "")
else:
return HeapSize(int(m.group(1)), suffix.upper())
else:
msg = "Invalid Heap Size '{}'! Format should be: '{}'".format(s, regex)
raise Exception(msg)

@classmethod
def from_bytes(cls, b: int):
if b % cls.G_FACTOR == 0:
return HeapSize(round(b / cls.G_FACTOR), "G")
if b % cls.M_FACTOR == 0:
return HeapSize(round(b / cls.M_FACTOR), "M")
if b % cls.K_FACTOR == 0:
return HeapSize(round(b / cls.K_FACTOR), "K")
return HeapSize(round(b), "")


def __str__(self):
return str(self.value) + self.suffix


def parseargs():
parser = argparse.ArgumentParser(
prog="find_heap_bound.py",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--verbose", "-v", action="count", default=0,
help="Increase verbosity level (cumulative)")
parser.add_argument("args", type=str, nargs="+",
help="Arguments to JVM, include classpath and main")
parser.add_argument("--start-size", type=str, default="4G",
help="Starting heap size")
parser.add_argument("--min-step", type=str, default="100M",
help="Minimum heap size step")
parser.add_argument("--java", type=str, default="java",
help="Java executable to use")
parser.add_argument("--timeout-factor", type=float, default=4.0,
help="Multiple of wallclock time of first successful run "
"that counts as a timeout, runs over this time count as a fail")
return parser.parse_args()


def get_logger(args):
logger = logging.getLogger("find_heap_bound")
if args.verbose == 1:
#logger.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO, format='%(message)s')
elif args.verbose >= 2:
logging.basicConfig(level=logging.DEBUG, format='%(message)s')
return logger


def mk_cmd(java, heap, args):
return [java, "-Xmx{}".format(heap)] + args


def job_failed_msg(e):
if isinstance(e, JobFailedError):
if "java.lang.OutOfMemoryError" in str(e):
return "Job failed, out of memory"
else:
return "Unexpected job failure\n{}".format(e)
elif isinstance(e, TimeoutExpired):
return "Job timed out at {} seconds".format(e.timeout)
else:
raise e


def main():
args = parseargs()
logger = get_logger(args)

results = []

min_step = HeapSize.from_str(args.min_step)
step = None
seen = set()
timeout = None # Set by first successful run
cur = HeapSize.from_str(args.start_size)
while cur not in seen:
seen.add(cur)
try:
cmd = mk_cmd(args.java, cur, args.args)
logger.info("Running {}".format(" ".join(cmd)))
stats = monitor_job(cmd, timeout=timeout)
logger.debug(stats)
if timeout is None:
timeout = stats.wall_clock_time * args.timeout_factor
logger.debug("Timeout set to {} s".format(timeout))
results.append((cur, stats))
if step is None:
step = (cur / 2).round_to(min_step)
else:
step = (step / 2).round_to(min_step)
cur = (cur - step).round_to(min_step)
except (JobFailedError, TimeoutExpired) as e:
logger.debug(job_failed_msg(e))
results.append((cur, None))
if step is None:
# Don't set step because we don't want to keep decreasing it
# when we haven't had a passing run yet
amt = (cur * 2).round_to(min_step)
else:
step = (step / 2).round_to(min_step)
amt = step
cur = (cur + step).round_to(min_step)
logger.debug("Next = {}, step = {}".format(cur, step))

sorted_results = sorted(results, key=lambda tup: tup[0].toBytes(), reverse=True)

table = [["Xmx", "Max RSS (MiB)", "Wall Clock (s)", "User Time (s)", "System Time (s)"]]
for heap, resources in sorted_results:
line = [str(heap)]
if resources is None:
line.extend(["-"]*4)
else:
line.append(str(resources.maxrss // 1024))
line.append(str(resources.wall_clock_time))
line.append(str(resources.user_time))
line.append(str(resources.system_time))
table.append(line)

csv = "\n".join([",".join(row) for row in table])
print(csv)


if __name__ == "__main__":
main()
Loading

0 comments on commit a7fe69b

Please sign in to comment.