Skip to content
This repository has been archived by the owner on Aug 20, 2024. It is now read-only.

Commit

Permalink
Merge branch '1.5.x' into 1.5-release
Browse files Browse the repository at this point in the history
  • Loading branch information
jackkoenig committed Jan 8, 2023
2 parents 26da1f6 + 01e567d commit b4054c2
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 46 deletions.
53 changes: 51 additions & 2 deletions benchmark/scripts/find_heap_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@
from typing import NamedTuple
from subprocess import TimeoutExpired
import logging
from functools import total_ordering

from monitor_job import monitor_job, JobFailedError

BaseHeapSize = NamedTuple('JavaHeapSize', [('value', int), ('suffix', str)])

@total_ordering
class HeapSize(BaseHeapSize):
K_FACTOR = 1024
M_FACTOR = 1024*1024
Expand Down Expand Up @@ -51,6 +54,16 @@ def __add__(self, rhs):
def __sub__(self, rhs):
return HeapSize.from_bytes(self.toBytes() - rhs.toBytes())


def __eq__(self, rhs):
return self.toBytes() == rhs.toBytes()

# Defining __eq__ for total_ordering forces us to explicitly inherit __hash__
__hash__ = BaseHeapSize.__hash__

def __ge__(self, rhs):
return self.toBytes() >= rhs.toBytes()

@classmethod
def from_str(cls, s: str):
regex = '(\d+)([kKmMgG])?'
Expand Down Expand Up @@ -97,6 +110,8 @@ def parseargs():
parser.add_argument("--timeout-factor", type=float, default=4.0,
help="Multiple of wallclock time of first successful run "
"that counts as a timeout, runs over this time count as a fail")
parser.add_argument("--context", type=int, default=0,
help="Number of extra steps above the minimum bound to run")
return parser.parse_args()


Expand Down Expand Up @@ -137,16 +152,23 @@ def main():
seen = set()
timeout = None # Set by first successful run
cur = HeapSize.from_str(args.start_size)
while cur not in seen:
last_success = cur

# Do binary search
while cur not in seen and (step is None or step >= min_step):
seen.add(cur)
try:
cmd = mk_cmd(args.java, cur, args.args)
logger.info("Running {}".format(" ".join(cmd)))
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Running {}".format(" ".join(cmd)))
else:
logger.info("Running {}".format(cur))
stats = monitor_job(cmd, timeout=timeout)
logger.debug(stats)
if timeout is None:
timeout = stats.wall_clock_time * args.timeout_factor
logger.debug("Timeout set to {} s".format(timeout))
last_success = cur
results.append((cur, stats))
if step is None:
step = (cur / 2).round_to(min_step)
Expand All @@ -166,6 +188,33 @@ def main():
cur = (cur + step).round_to(min_step)
logger.debug("Next = {}, step = {}".format(cur, step))

# Run extra steps for some context above the minimum size
extra_steps = []
if args.context > 0:
for i in range(1, args.context):
diff = min_step * i
heap_size = last_success + diff
if heap_size not in seen:
extra_steps.append(heap_size)
log_steps = ", ".join([str(e) for e in extra_steps]) # Pretty print
logger.info("Because context is {}, running extra heap sizes: {}".format(args.context, log_steps))

for cur in extra_steps:
logger.debug("Next = {}".format(cur))
seen.add(cur)
try:
cmd = mk_cmd(args.java, cur, args.args)
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Running {}".format(" ".join(cmd)))
else:
logger.info("Running {}".format(cur))
stats = monitor_job(cmd, timeout=timeout)
logger.debug(stats)
results.append((cur, stats))
except (JobFailedError, TimeoutExpired) as e:
logger.debug(job_failed_msg(e))
results.append((cur, None))

sorted_results = sorted(results, key=lambda tup: tup[0].toBytes(), reverse=True)

table = [["Xmx", "Max RSS (MiB)", "Wall Clock (s)", "User Time (s)", "System Time (s)"]]
Expand Down
8 changes: 6 additions & 2 deletions src/main/scala/firrtl/passes/CheckWidths.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,13 @@ object CheckWidths extends Pass {
// This is a leaf check of the "local" width-correctness of one expression node, so no recursion.
expr match {
case e @ UIntLiteral(v, w: IntWidth) if math.max(1, v.bitLength) > w.width =>
errors.append(new WidthTooSmall(info, target.serialize, v))
if (w.width > 0 || (w.width == 0 && v != 0)) { // UInt<0>(0) is allowed
errors.append(new WidthTooSmall(info, target.serialize, v))
}
case e @ SIntLiteral(v, w: IntWidth) if v.bitLength + 1 > w.width =>
errors.append(new WidthTooSmall(info, target.serialize, v))
if (w.width > 0 || (w.width == 0 && v != 0)) { // SInt<0>(0) is allowed
errors.append(new WidthTooSmall(info, target.serialize, v))
}
case e @ DoPrim(op, Seq(a, b), _, tpe) =>
(op, a.tpe, b.tpe) match {
case (Squeeze, IntervalType(Closed(la), Closed(ua), _), IntervalType(Closed(lb), Closed(ub), _))
Expand Down
4 changes: 4 additions & 0 deletions src/test/scala/firrtl/testutils/LeanTransformSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class LeanTransformSpec(protected val transforms: Seq[TransformDependency])
actual should be(expected)
finalState
}
protected def removeSkip(c: ir.Circuit): ir.Circuit = {
def onStmt(s: ir.Statement): ir.Statement = s.mapStmt(onStmt)
c.mapModule(m => m.mapStmt(onStmt))
}
}

private object LeanTransformSpec {
Expand Down
89 changes: 47 additions & 42 deletions src/test/scala/firrtlTests/ZeroWidthTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,11 @@
package firrtlTests

import firrtl._
import firrtl.options.Dependency
import firrtl.passes._
import firrtl.testutils._

class ZeroWidthTests extends FirrtlFlatSpec {
def transforms = Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, ZeroWidth)
private def exec(input: String) = {
val circuit = parse(input)
transforms
.foldLeft(CircuitState(circuit, UnknownForm)) { (c: CircuitState, p: Transform) =>
p.runTransform(c)
}
.circuit
.serialize
}
class ZeroWidthTests extends LeanTransformSpec(Seq(Dependency(ZeroWidth))) {
// =============================
"Zero width port" should " be deleted" in {
val input =
Expand All @@ -30,7 +21,7 @@ class ZeroWidthTests extends FirrtlFlatSpec {
| module Top :
| output x : UInt<1>
| x <= UInt<1>(0)""".stripMargin
(parse(exec(input))) should be(parse(check))
compile(input).circuit.serialize should be(parse(check).serialize)
}
"Add of <0> and <2> " should " put in zero" in {
val input =
Expand All @@ -44,35 +35,36 @@ class ZeroWidthTests extends FirrtlFlatSpec {
| module Top :
| output x : UInt<3>
| x <= add(UInt<1>(0), UInt<2>(2))""".stripMargin
(parse(exec(input)).serialize) should be(parse(check).serialize)
compile(input).circuit.serialize should be(parse(check).serialize)
}
"Mux on <0>" should "put in zero" in {
"Mux on <0>" should "not be allowed" in {
// Note that this used to be allowed, but the support seems to have bit-rotted
// and modern firrtl enforces 1-bit UInt for muxes.
val input =
"""circuit Top :
| module Top :
| input y : UInt<0>
| output x : UInt
| x <= mux(y, UInt<2>(2), UInt<2>(1))""".stripMargin
val check =
"""circuit Top :
| module Top :
| output x : UInt<2>
| x <= mux(UInt<1>(0), UInt<2>(2), UInt<2>(1))""".stripMargin
(parse(exec(input)).serialize) should be(parse(check).serialize)
val e = intercept[PassException] { compile(input) }
assert(e.getMessage.contains("A mux condition must be of type 1-bit UInt"))
}
"Bundle with field of <0>" should "get deleted" in {
val input =
"""circuit Top :
| module Top :
| input y : { a: UInt<0> }
| output x : { a: UInt<0>, b: UInt<1>}
| x.b <= UInt(1)
| x.a <= y.a""".stripMargin
val check =
"""circuit Top :
| module Top :
| output x : { b: UInt<1> }
| skip""".stripMargin
(parse(exec(input)).serialize) should be(parse(check).serialize)
| skip
| x.b <= UInt(1)
| """.stripMargin
compile(input).circuit.serialize should be(parse(check).serialize)
}
"Vector with type of <0>" should "get deleted" in {
val input =
Expand All @@ -85,7 +77,7 @@ class ZeroWidthTests extends FirrtlFlatSpec {
"""circuit Top :
| module Top :
| skip""".stripMargin
(parse(exec(input)).serialize) should be(parse(check).serialize)
removeSkip(compile(input).circuit).serialize should be(parse(check).serialize)
}
"Node with <0>" should "be removed" in {
val input =
Expand All @@ -97,7 +89,7 @@ class ZeroWidthTests extends FirrtlFlatSpec {
"""circuit Top :
| module Top :
| skip""".stripMargin
(parse(exec(input)).serialize) should be(parse(check).serialize)
compile(input).circuit.serialize should be(parse(check).serialize)
}
"IsInvalid on <0>" should "be deleted" in {
val input =
Expand All @@ -109,7 +101,7 @@ class ZeroWidthTests extends FirrtlFlatSpec {
"""circuit Top :
| module Top :
| skip""".stripMargin
(parse(exec(input)).serialize) should be(parse(check).serialize)
compile(input).circuit.serialize should be(parse(check).serialize)
}
"Expression in node with type <0>" should "be replaced by UInt<1>(0)" in {
val input =
Expand All @@ -123,7 +115,7 @@ class ZeroWidthTests extends FirrtlFlatSpec {
| module Top :
| input x: UInt<1>
| node z = add(x, UInt<1>(0))""".stripMargin
(parse(exec(input)).serialize) should be(parse(check).serialize)
compile(input).circuit.serialize should be(parse(check).serialize)
}
"Expression in cat with type <0>" should "be removed" in {
val input =
Expand All @@ -137,7 +129,7 @@ class ZeroWidthTests extends FirrtlFlatSpec {
| module Top :
| input x: UInt<1>
| node z = x""".stripMargin
(parse(exec(input)).serialize) should be(parse(check).serialize)
compile(input).circuit.serialize should be(parse(check).serialize)
}
"Nested cats with type <0>" should "be removed" in {
val input =
Expand All @@ -151,7 +143,7 @@ class ZeroWidthTests extends FirrtlFlatSpec {
"""circuit Top :
| module Top :
| skip""".stripMargin
(parse(exec(input)).serialize) should be(parse(check).serialize)
compile(input).circuit.serialize should be(parse(check).serialize)
}
"Nested cats where one has type <0>" should "be unaffected" in {
val input =
Expand All @@ -167,9 +159,11 @@ class ZeroWidthTests extends FirrtlFlatSpec {
| input x: UInt<1>
| input z: UInt<1>
| node a = cat(x, z)""".stripMargin
(parse(exec(input)).serialize) should be(parse(check).serialize)
compile(input).circuit.serialize should be(parse(check).serialize)
}
"Stop with type <0>" should "be replaced with UInt(0)" in {
// Note that this used to be allowed, but the support seems to have bit-rotted
// and modern firrtl enforces 1-bit UInt for stop enables.
val input =
"""circuit Top :
| module Top :
Expand All @@ -178,14 +172,8 @@ class ZeroWidthTests extends FirrtlFlatSpec {
| input y: UInt<0>
| input z: UInt<1>
| stop(clk, y, 1)""".stripMargin
val check =
"""circuit Top :
| module Top :
| input clk: Clock
| input x: UInt<1>
| input z: UInt<1>
| stop(clk, UInt(0), 1)""".stripMargin
(parse(exec(input)).serialize) should be(parse(check).serialize)
val e = intercept[PassException] { compile(input) }
assert(e.getMessage.contains("Enable must be a 1-bit UIntType typed signal"))
}
"Print with type <0>" should "be replaced with UInt(0)" in {
val input =
Expand All @@ -203,7 +191,7 @@ class ZeroWidthTests extends FirrtlFlatSpec {
| input x: UInt<1>
| input z: UInt<1>
| printf(clk, UInt(1), "%d %d %d\n", x, UInt(0), z)""".stripMargin
(parse(exec(input)).serialize) should be(parse(check).serialize)
compile(input).circuit.serialize should be(parse(check).serialize)
}

"Andr of zero-width expression" should "return true" in {
Expand All @@ -218,7 +206,7 @@ class ZeroWidthTests extends FirrtlFlatSpec {
| module Top :
| output x : UInt<1>
| x <= UInt<1>(1)""".stripMargin
(parse(exec(input))) should be(parse(check))
compile(input).circuit.serialize should be(parse(check).serialize)
}

"Cat of SInt with zero-width" should "keep type correctly" in {
Expand All @@ -235,7 +223,7 @@ class ZeroWidthTests extends FirrtlFlatSpec {
| input y : SInt<1>
| output z : UInt<1>
| z <= asUInt(y)""".stripMargin
(parse(exec(input))) should be(parse(check))
compile(input).circuit.serialize should be(parse(check).serialize)
}

"dshl with zero-width" should "canonicalize to the un-shifted expression" in {
Expand All @@ -252,7 +240,7 @@ class ZeroWidthTests extends FirrtlFlatSpec {
| input y : SInt<1>
| output z : SInt<1>
| z <= y""".stripMargin
(parse(exec(input))) should be(parse(check))
compile(input).circuit.serialize should be(parse(check).serialize)
}

"Memories with zero-width data-type" should "be fully removed" in {
Expand Down Expand Up @@ -315,7 +303,24 @@ class ZeroWidthTests extends FirrtlFlatSpec {
| input rwMask: UInt<1>
|
|${Seq.tabulate(17)(_ => " skip").mkString("\n")}""".stripMargin
parse(exec(input)) should be(parse(check))
compile(input).circuit.serialize should be(parse(check).serialize)
}

"zero width literals" should "be permissible" in {
val input =
"""circuit Foo:
| module Foo:
| output x : UInt<1>
| output y : SInt<3>
|
| x <= UInt<0>(0)
| y <= SInt<0>(0)
|""".stripMargin

val result = compile(input).circuit
val lines = result.serialize.split('\n').map(_.trim)
assert(lines.contains("x <= UInt<1>(\"h0\")"))
assert(lines.contains("y <= SInt<1>(\"h0\")"))
}
}

Expand Down

0 comments on commit b4054c2

Please sign in to comment.