Skip to content

Commit

Permalink
Merge pull request #31 from OpShin/fix/unique_variables_for_machine
Browse files Browse the repository at this point in the history
Apply a unique variable vistor before evaluating
  • Loading branch information
nielstron committed Jan 2, 2024
2 parents cff4d81 + 05f0e5e commit 794fd30
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 46 deletions.
13 changes: 12 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "uplc"
version = "1.0.0"
version = "1.0.1"
description = "Python implementation of untyped plutus language core"
authors = ["nielstron <n.muendler@web.de>"]
license = "MIT"
Expand All @@ -23,7 +23,7 @@ packages = [{include = "uplc"}]
python = ">=3.8, <3.12"
frozendict = "^2.3.8"
cbor2 = "^5.4.6"
frozenlist = "^1.3.3"
frozenlist2 = "^1.0.0"
rply = "^0.7.8"
pycardano = "^0.9.0"
python-secp256k1-cardano = "^0.2.3"
Expand Down
2 changes: 1 addition & 1 deletion uplc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging


__version__ = "1.0.0"
__version__ = "1.0.1"
__author__ = "nielstron"
__author_email__ = "n.muendler@web.de"
__copyright__ = "Copyright (C) 2023 nielstron"
Expand Down
30 changes: 18 additions & 12 deletions uplc/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import cbor2
import frozendict
import frozenlist
from frozenlist2 import frozenlist
import nacl.exceptions
from _cbor2 import CBOREncoder
from pycardano.crypto.bip32 import BIP32ED25519PublicKey
Expand Down Expand Up @@ -316,7 +316,7 @@ class BuiltinList(Constant):
sample_value: Constant

def __init__(self, values, sample_value=None):
object.__setattr__(self, "values", values)
object.__setattr__(self, "values", frozenlist(values))
if not values:
assert (
sample_value is not None
Expand Down Expand Up @@ -427,7 +427,10 @@ def d_ex_mem(self) -> int:

@dataclass(frozen=True, eq=True)
class PlutusList(PlutusData):
value: Union[List[PlutusData], frozenlist.FrozenList]
value: Union[List[PlutusData], frozenlist]

def __post_init__(self):
object.__setattr__(self, "value", frozenlist(self.value))

def to_cbor(self):
return [d.to_cbor() for d in self.value]
Expand All @@ -446,6 +449,10 @@ def d_ex_mem(self) -> int:
class PlutusMap(PlutusData):
value: Union[Dict[PlutusData, PlutusData], frozendict.frozendict]

def __post_init__(self):
frozen_value = frozendict.frozendict(self.value)
object.__setattr__(self, "value", frozen_value)

def to_cbor(self):
return {k.to_cbor(): v.to_cbor() for k, v in self.value.items()}

Expand All @@ -468,7 +475,10 @@ def d_ex_mem(self) -> int:
@dataclass(frozen=True, eq=True)
class PlutusConstr(PlutusData):
constructor: int
fields: Union[List[PlutusData], frozenlist.FrozenList]
fields: Union[List[PlutusData], frozenlist]

def __post_init__(self):
object.__setattr__(self, "fields", frozenlist(self.fields))

def to_cbor(self):
fields = (
Expand Down Expand Up @@ -560,16 +570,14 @@ def data_from_cbortag(cbor) -> PlutusData:
constructor, fields = cbor.value
else:
raise ValueError(f"Invalid cbor with tag {cbor.tag}")
fields = frozenlist.FrozenList(list(map(data_from_cbortag, fields)))
fields.freeze()
fields = frozenlist(list(map(data_from_cbortag, fields)))
return PlutusConstr(constructor, fields)
if isinstance(cbor, int):
return PlutusInteger(cbor)
if isinstance(cbor, bytes):
return PlutusByteString(cbor)
if isinstance(cbor, list):
entries = frozenlist.FrozenList(list(map(data_from_cbortag, cbor)))
entries.freeze()
entries = frozenlist(list(map(data_from_cbortag, cbor)))
return PlutusList(entries)
if isinstance(cbor, dict):
return PlutusMap(
Expand All @@ -587,16 +595,14 @@ def data_from_cbor(cbor: bytes) -> PlutusData:

def data_from_json_dict(d: dict) -> PlutusData:
if "constructor" in d:
fields = frozenlist.FrozenList([data_from_json_dict(f) for f in d["fields"]])
fields.freeze()
fields = frozenlist([data_from_json_dict(f) for f in d["fields"]])
return PlutusConstr(d["constructor"], fields)
if "int" in d:
return PlutusInteger(d["int"])
if "bytes" in d:
return PlutusByteString(bytes.fromhex(d["bytes"]))
if "list" in d:
entries = frozenlist.FrozenList(list(map(data_from_json_dict, d["list"])))
entries.freeze()
entries = frozenlist(list(map(data_from_json_dict, d["list"])))
return PlutusList(entries)
if "map" in d:
return PlutusMap(
Expand Down
15 changes: 15 additions & 0 deletions uplc/cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,21 @@ def __rsub__(self, other: "Budget") -> "Budget":
def __rmul__(self, other: int) -> "Budget":
return Budget(self.cpu * other, self.memory * other)

def __ge__(self, other: "Budget") -> bool:
return self.cpu >= other.cpu and self.memory >= other.memory

def __gt__(self, other: "Budget") -> bool:
return self.cpu > other.cpu and self.memory > other.memory

def __lt__(self, other: "Budget") -> bool:
return self.cpu < other.cpu and self.memory < other.memory

def __le__(self, other: "Budget") -> bool:
return self.cpu <= other.cpu and self.memory <= other.memory

def __eq__(self, other: "Budget") -> bool:
return self.cpu == other.cpu and self.memory == other.memory

def exhausted(self):
return self.cpu < 0 or self.memory < 0

Expand Down
9 changes: 9 additions & 0 deletions uplc/machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import copy

from .ast import *
from .transformer.unique_variables import UniqueVariableTransformer, FreeVariableError
from .cost_model import CekMachineCostModel, BuiltinCostModel, CekOp, Budget

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -80,6 +81,14 @@ def spend_unbudgeted_steps(self):
# Compute methods

def eval(self, program: Program):
try:
program = UniqueVariableTransformer().visit(program)
except FreeVariableError as e:
return ComputationResult(
e,
[],
Budget(0, 0),
)
self.remaining_budget = copy.copy(self.budget)
self.logs = []
stack = [
Expand Down
7 changes: 5 additions & 2 deletions uplc/optimizer/pre_evaluation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ..util import NodeTransformer
from ..ast import Program, AST
from ..machine import Machine
from ..tools import eval

"""
Optimizes code by pre-evaluating each subterm
Expand All @@ -14,7 +14,10 @@ def visit_Program(self, node: Program) -> Program:

def generic_visit(self, node: AST) -> AST:
try:
nc = Machine(node).eval()
nc = eval(node).result
except Exception as e:
nc = node
else:
if isinstance(nc, Exception):
nc = node
return super().generic_visit(nc)
16 changes: 12 additions & 4 deletions uplc/tests/test_acceptance.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,15 @@ def test_acceptance_tests(self, _, dirpath, rewriter):
return
cost = json.loads(cost_content)
expected_spent_budget = Budget(cost["cpu"], cost["mem"])
self.assertEqual(
expected_spent_budget, comp_res.cost, "Program evaluated with wrong cost."
)
# TODO check logs
if rewriter == pre_evaluation.PreEvaluationOptimizer:
self.assertGreaterEqual(
expected_spent_budget,
comp_res.cost,
"Program cost more after preeval rewrite",
)
else:
self.assertEqual(
expected_spent_budget,
comp_res.cost,
"Program evaluated with wrong cost.",
)
53 changes: 29 additions & 24 deletions uplc/tests/test_roundtrips.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import hypothesis
from hypothesis import strategies as hst
import frozenlist as fl
from frozenlist2 import frozenlist
import pyaiken
from parameterized import parameterized

Expand All @@ -18,12 +18,6 @@
from ..util import NodeVisitor


def frozenlist(l):
l = fl.FrozenList(l)
l.freeze()
return l


pos_int = hst.integers(min_value=0, max_value=2**64 - 1)


Expand Down Expand Up @@ -299,6 +293,8 @@ def test_preeval_no_semantic_change(self, p):
try:
orig_res = orig_p
for _ in range(100):
if isinstance(orig_res, Exception):
break
if isinstance(orig_res, BoundStateLambda) or isinstance(
orig_res, ForcedBuiltIn
):
Expand All @@ -307,39 +303,48 @@ def test_preeval_no_semantic_change(self, p):
orig_res = Apply(orig_res, p)
if isinstance(orig_res, BoundStateDelay):
orig_res = Force(orig_res)
orig_res = eval(orig_res)
if not isinstance(orig_res.result, Exception):
orig_res = unique_variables.UniqueVariableTransformer().visit(
orig_res.result
)
else:
orig_res = str(orig_res.result)
orig_res = eval(orig_res).result
if not isinstance(orig_res, Exception):
orig_res = unique_variables.UniqueVariableTransformer().visit(orig_res)
except unique_variables.FreeVariableError:
self.fail(f"Free variable error occurred after evaluation in {code}")
try:
rewrite_res = rewrite_p
for _ in range(100):
if isinstance(rewrite_res, Exception):
break
if isinstance(rewrite_res, BoundStateLambda) or isinstance(
rewrite_res, ForcedBuiltIn
):
p = params.pop(0)
rewrite_res = Apply(rewrite_res, p)
if isinstance(rewrite_res, BoundStateDelay):
rewrite_res = Force(rewrite_res)
rewrite_res = eval(rewrite_res)
if not isinstance(rewrite_res.result, Exception):
rewrite_res = eval(rewrite_res).result
if not isinstance(rewrite_res, Exception):
rewrite_res = unique_variables.UniqueVariableTransformer().visit(
rewrite_res.result
rewrite_res
)
else:
rewrite_res = str(rewrite_res.result)
except unique_variables.FreeVariableError:
self.fail(f"Free variable error occurred after evaluation in {code}")
self.assertEqual(
orig_res,
rewrite_res,
f"Two programs evaluate to different results after optimization in {code}",
)
if not isinstance(rewrite_res, Exception):
if isinstance(orig_res, Exception):
self.assertIsInstance(
orig_res,
RuntimeError,
"Original code resulted in something different than a runtime error (exceeding budget) and rewritten result is ok",
)
self.assertEqual(
orig_res,
rewrite_res,
f"Two programs evaluate to different results after optimization in {code}",
)
else:
self.assertIsInstance(
orig_res,
Exception,
"Rewrite result was exception but orig result is not an exception",
)

@hypothesis.given(uplc_program_valid)
@hypothesis.settings(max_examples=1000, deadline=datetime.timedelta(seconds=10))
Expand Down

0 comments on commit 794fd30

Please sign in to comment.