Files
Gwendolyn/gwendolyn/funcs/other/roll/ast_nodes.py
2024-10-31 22:38:59 +01:00

737 lines
19 KiB
Python

from __future__ import annotations
from random import randint
from rply.token import BaseBox
def show_exp(exp, vtable):
    """Render ``exp`` as text, wrapping it in parentheses unless it is atomic.

    Nodes whose exact type is ``ExpInt``, ``ExpRoll`` or ``ExpVar`` are
    returned as rendered; any other node is surrounded by ``(`` and ``)``.
    """
    rendered = exp.show(vtable)
    # Exact type comparison (not isinstance): subclasses of the three atomic
    # node types are still parenthesised.
    if type(exp) in (ExpInt, ExpRoll, ExpVar):
        return rendered
    return f"({rendered})"
class Exp(BaseBox):
    """Base class of the expression nodes.

    ``eval`` and ``show`` are the entry points. Each passes a copy of
    ``vtable`` (made with ``vtable.copy()``) plus any extra positional
    arguments to the ``_eval`` / ``_show`` hook, so the hook works on its
    own copy rather than on the caller's table.
    """

    def eval(self, vtable, *args):
        """Evaluate this node against a copy of ``vtable``."""
        scope = vtable.copy()
        return self._eval(scope, *args)

    def show(self, vtable, *args):
        """Render this node against a copy of ``vtable``."""
        scope = vtable.copy()
        return self._show(scope, *args)

    def _eval(self, _):
        """Evaluation hook; the base node has no value."""
        return None

    def _show(self, _):
        """Rendering hook; the base node renders as an empty string."""
        return ""

    def includes(self, _):
        """Variable-occurrence query; the base node contains no variables."""
        return False

    def __repr__(self) -> str:
        return "exp()"

    def __eq__(self, other: Exp) -> bool:
        # A bare Exp compares unequal to everything, including itself.
        return False
class ExpInt(Exp):
    """Integer literal node."""

    def __init__(self, value: int):
        self.value = value

    def _eval(self, _):
        """Return the stored literal."""
        return self.value

    def _show(self, _):
        """Return the literal converted with ``str``."""
        return str(self.value)

    def includes(self, _):
        """A literal never contains a variable."""
        return False

    def __repr__(self) -> str:
        return f"exp_int({self.value})"

    def __eq__(self, other: Exp) -> bool:
        if not isinstance(other, ExpInt):
            return False
        return self.value == other.value
class ExpMin(Exp):
    """Binary ``min`` node: ``exp2`` acts as a lower bound on ``exp1``.

    Evaluation yields the larger of the two operand values (``max``).
    NOTE(review): presumably the ``min`` keyword means "at least"; confirm
    against the grammar.
    """

    def __init__(self, exp1: Exp, exp2: Exp):
        self.exp1 = exp1
        self.exp2 = exp2

    def _eval(self, vtable):
        """Return ``max`` of both operands, or ``None`` if either is not an int."""
        value = self.exp1.eval(vtable)
        bound = self.exp2.eval(vtable)
        if isinstance(value, int) and isinstance(bound, int):
            return max(value, bound)
        return None

    def _show(self, vtable):
        """Render as ``<exp1>min<exp2>``."""
        left = show_exp(self.exp1, vtable)
        right = show_exp(self.exp2, vtable)
        return f"{left}min{right}"

    def includes(self, var: str):
        """Return whether either operand reports ``var`` as included."""
        return self.exp1.includes(var) or self.exp2.includes(var)

    def __repr__(self) -> str:
        return f"exp_min({self.exp1},{self.exp2})"

    def __eq__(self, other: Exp) -> bool:
        if not isinstance(other, ExpMin):
            return False
        return self.exp1 == other.exp1 and self.exp2 == other.exp2
class ExpMax(Exp):
    """Binary ``max`` node: ``exp2`` acts as an upper bound on ``exp1``.

    Evaluation yields the smaller of the two operand values (``min``).
    NOTE(review): presumably the ``max`` keyword means "at most"; confirm
    against the grammar.
    """

    def __init__(self, exp1: Exp, exp2: Exp):
        self.exp1 = exp1
        self.exp2 = exp2

    def _eval(self, vtable):
        """Return ``min`` of both operands, or ``None`` if either is not an int."""
        value = self.exp1.eval(vtable)
        bound = self.exp2.eval(vtable)
        if isinstance(value, int) and isinstance(bound, int):
            return min(value, bound)
        return None

    def _show(self, vtable):
        """Render as ``<exp1>max<exp2>``."""
        left = show_exp(self.exp1, vtable)
        right = show_exp(self.exp2, vtable)
        return f"{left}max{right}"

    def includes(self, var: str):
        """Return whether either operand reports ``var`` as included."""
        return self.exp1.includes(var) or self.exp2.includes(var)

    def __repr__(self) -> str:
        return f"exp_max({self.exp1},{self.exp2})"

    def __eq__(self, other: Exp) -> bool:
        if not isinstance(other, ExpMax):
            return False
        return self.exp1 == other.exp1 and self.exp2 == other.exp2
class ExpBinop(Exp):
    """Arithmetic node applying ``+``, ``-``, ``*`` or ``/`` to two operands."""

    def __init__(self, op: str, exp1: Exp, exp2: Exp):
        self.op = op
        self.exp1 = exp1
        self.exp2 = exp2

    def _eval(self, vtable):
        """Apply the operator to both operand values.

        Returns ``None`` when either operand does not evaluate to an int.
        ``/`` is floor division. Raises ``Exception`` for an operator other
        than the four supported ones.
        """
        left = self.exp1.eval(vtable)
        right = self.exp2.eval(vtable)
        if not isinstance(left, int) or not isinstance(right, int):
            return None
        operations = {
            "+": lambda a, b: a + b,
            "-": lambda a, b: a - b,
            "*": lambda a, b: a * b,
            "/": lambda a, b: a // b,
        }
        if self.op not in operations:
            raise Exception(f"Unknown binop {self.op}")
        return operations[self.op](left, right)

    def _show(self, vtable):
        """Render as ``<exp1><op><exp2>``."""
        left = show_exp(self.exp1, vtable)
        right = show_exp(self.exp2, vtable)
        return f"{left}{self.op}{right}"

    def includes(self, var: str):
        """Return whether either operand reports ``var`` as included."""
        return self.exp1.includes(var) or self.exp2.includes(var)

    def __repr__(self) -> str:
        return f"exp_binop({self.op}, {self.exp1}, {self.exp2})"

    def __eq__(self, other: Exp) -> bool:
        if not isinstance(other, ExpBinop):
            return False
        return (
            self.op == other.op
            and self.exp1 == other.exp1
            and self.exp2 == other.exp2
        )
class ExpNeg(Exp):
    """Unary minus node."""

    def __init__(self, exp: Exp):
        self.exp = exp

    def _eval(self, vtable):
        """Return the negated operand, or ``None`` if it is not an int."""
        value = self.exp.eval(vtable)
        if isinstance(value, int):
            return -value
        return None

    def _show(self, vtable):
        """Render as ``-<exp>``."""
        inner = show_exp(self.exp, vtable)
        return f"-{inner}"

    def includes(self, var: str):
        """Return whether the operand reports ``var`` as included."""
        return self.exp.includes(var)

    def __repr__(self) -> str:
        return f"exp_neg({self.exp})"

    def __eq__(self, other: Exp) -> bool:
        if not isinstance(other, ExpNeg):
            return False
        return self.exp == other.exp
class ExpLet(Exp):
    """``let var = exp1 in exp2`` node."""

    def __init__(self, var: str, exp1: Exp, exp2: Exp):
        self.var = var
        self.exp1 = exp1
        self.exp2 = exp2

    def _eval(self, vtable):
        """Bind ``var`` to the value of ``exp1`` and evaluate ``exp2``."""
        vtable[self.var] = self.exp1.eval(vtable)
        return self.exp2.eval(vtable)

    def _show(self, vtable):
        """Render the binding, or only the body when it does not use ``var``."""
        bound_text = show_exp(self.exp1, vtable)
        # The body is rendered with the variable bound to its evaluated value.
        vtable[self.var] = self.exp1.eval(vtable)
        body_text = show_exp(self.exp2, vtable)
        if not self.exp2.includes(self.var):
            return body_text
        return f"let {self.var}={bound_text} in {body_text}"

    def includes(self, var: str):
        """Return whether either sub-expression reports ``var`` as included."""
        return self.exp1.includes(var) or self.exp2.includes(var)

    def __repr__(self) -> str:
        return f"exp_let({self.var}, {self.exp1}, {self.exp2})"

    def __eq__(self, other: Exp) -> bool:
        if not isinstance(other, ExpLet):
            return False
        return (
            self.var == other.var
            and self.exp1 == other.exp1
            and self.exp2 == other.exp2
        )
class ExpIf(Exp):
    """``if exp1 then exp2 else exp3`` node."""

    def __init__(self, exp1: Exp, exp2: Exp, exp3: Exp):
        self.exp1 = exp1
        self.exp2 = exp2
        self.exp3 = exp3

    def _eval(self, vtable):
        """Evaluate ``exp2`` when the condition is positive, else ``exp3``.

        Returns ``None`` when the condition does not evaluate to an int,
        in line with how the arithmetic nodes treat non-int operands.
        """
        condition = self.exp1.eval(vtable)
        # Without this guard a None or lambda-tuple condition would raise
        # TypeError on the ``> 0`` comparison below.
        if not isinstance(condition, int):
            return None
        branch = self.exp2 if condition > 0 else self.exp3
        return branch.eval(vtable)

    def _show(self, vtable):
        """Render as ``if <exp1> then <exp2> else <exp3>``."""
        r1 = show_exp(self.exp1, vtable)
        r2 = show_exp(self.exp2, vtable)
        # The else branch is rendered directly, without show_exp's parentheses.
        r3 = self.exp3.show(vtable)
        return f"if {r1} then {r2} else {r3}"

    def includes(self, var: str):
        """Return whether any of the three sub-expressions includes ``var``."""
        return (
            self.exp1.includes(var)
            or self.exp2.includes(var)
            or self.exp3.includes(var)
        )

    def __repr__(self) -> str:
        return f"exp_if({self.exp1}, {self.exp2}, {self.exp3})"

    def __eq__(self, other: Exp) -> bool:
        return (
            isinstance(other, ExpIf) and
            self.exp1 == other.exp1 and
            self.exp2 == other.exp2 and
            self.exp3 == other.exp3
        )
class ExpLambda(Exp):
    """Single-parameter lambda node."""

    def __init__(self, var: str, exp: Exp):
        self.var = var
        self.exp = exp

    def _eval(self, _):
        """Return the ``(body, parameter name)`` tuple without evaluating the body."""
        body, parameter = self.exp, self.var
        return (body, parameter)

    def _show(self, vtable):
        """Render as a backslash, the parameter, ``->`` and the body."""
        body_text = show_exp(self.exp, vtable)
        return f"\\{self.var} -> {body_text}"

    def includes(self, var: str):
        """Return whether the body reports ``var`` as included."""
        return self.exp.includes(var)

    def __repr__(self) -> str:
        return f"exp_lambda({self.var}, {self.exp})"

    def __eq__(self, other: Exp) -> bool:
        if not isinstance(other, ExpLambda):
            return False
        return self.var == other.var and self.exp == other.exp
class ExpApply(Exp):
    """Application node: calls the value of ``exp1`` with ``exp2``."""

    def __init__(self, exp1: Exp, exp2: Exp):
        self.exp1 = exp1
        self.exp2 = exp2

    def _eval(self, vtable):
        """Apply a ``(body, parameter)`` tuple to the argument value.

        Returns ``None`` when ``exp1`` does not evaluate to a tuple.
        """
        closure = self.exp1.eval(vtable)
        if not isinstance(closure, tuple):
            return None
        argument = self.exp2.eval(vtable)
        # closure[1] is the parameter name, closure[0] the body expression.
        vtable[closure[1]] = argument
        return closure[0].eval(vtable)

    def _show(self, vtable):
        """Render as ``<exp1>(<exp2>)``."""
        function_text = show_exp(self.exp1, vtable)
        argument_text = self.exp2.show(vtable)
        return f"{function_text}({argument_text})"

    def includes(self, var: str):
        """Return whether either sub-expression reports ``var`` as included."""
        return self.exp1.includes(var) or self.exp2.includes(var)

    def __repr__(self) -> str:
        return f"exp_apply({self.exp1}, {self.exp2})"

    def __eq__(self, other: Exp) -> bool:
        if not isinstance(other, ExpApply):
            return False
        return self.exp1 == other.exp1 and self.exp2 == other.exp2
class ExpVar(Exp):
    """Variable reference node."""

    def __init__(self, var: str):
        self.var = var

    def _eval(self, vtable):
        """Return the value bound to the variable, or ``None`` if unbound."""
        if self.var in vtable:
            return vtable[self.var]
        return None

    def _show(self, vtable):
        """Render as the variable name itself."""
        return self.var

    def includes(self, var: str):
        """Return whether ``var`` is this node's variable name."""
        return self.var == var

    def __repr__(self) -> str:
        return f"exp_var({self.var})"

    def __eq__(self, other: Exp) -> bool:
        if not isinstance(other, ExpVar):
            return False
        return self.var == other.var
class ComparePoint(Exp):
    """Comparison operator together with the expression to compare against.

    Unlike the other nodes, evaluation needs the value under test; it is
    passed as the extra positional argument of ``eval(vtable, val)``.
    """

    def __init__(self, comp_op: str, exp: Exp) -> None:
        self.comp_op = comp_op
        self.exp = exp

    def _eval(self, vtable, val: int):
        """Compare ``val`` with the value of ``exp``.

        Returns ``1`` when ``val <comp_op> exp`` holds and ``0`` otherwise,
        or ``None`` when either side is not an int. Raises ``Exception``
        for an operator other than ``=``, ``<``, ``<=``, ``>`` or ``>=``.
        """
        target = self.exp.eval(vtable)
        if not (isinstance(target, int) and isinstance(val, int)):
            return None
        if self.comp_op == "=":
            holds = val == target
        elif self.comp_op == "<":
            holds = val < target
        elif self.comp_op == "<=":
            holds = val <= target
        elif self.comp_op == ">":
            holds = val > target
        elif self.comp_op == ">=":
            holds = val >= target
        else:
            raise Exception(f"Unknown comparison operator {self.comp_op}")
        return 1 if holds else 0

    def _show(self, vtable):
        """Render as ``<comp_op><exp>``."""
        r = show_exp(self.exp, vtable)
        return f"{self.comp_op}{r}"

    def includes(self, var: str):
        """Return whether the compared expression reports ``var`` as included."""
        return self.exp.includes(var)

    def __repr__(self) -> str:
        return f"comp({self.comp_op},{self.exp})"

    def __eq__(self, other: Exp) -> bool:
        return (
            isinstance(other, ComparePoint) and
            self.comp_op == other.comp_op and
            self.exp == other.exp
        )
class ExpTest(Exp):
    """Node that checks the value of an expression against a ``ComparePoint``."""

    def __init__(self, exp: Exp, comp: ComparePoint) -> None:
        self.exp = exp
        self.comp = comp

    def _eval(self, vtable):
        """Evaluate ``exp`` and return what ``comp`` yields for that value."""
        value = self.exp.eval(vtable)
        return self.comp.eval(vtable, value)

    def _show(self, vtable):
        """Render the expression immediately followed by the comparison."""
        subject = show_exp(self.exp, vtable)
        comparison = self.comp.show(vtable)
        return f"{subject}{comparison}"

    def includes(self, var: str):
        """Return whether the expression or the comparison includes ``var``."""
        return self.exp.includes(var) or self.comp.includes(var)

    def __repr__(self) -> str:
        return f"test({self.exp},{self.comp})"

    def __eq__(self, other: Exp) -> bool:
        if not isinstance(other, ExpTest):
            return False
        return self.exp == other.exp and self.comp == other.comp
class ExpRoll(Exp):
    """Expression node whose value is the sum of a roll's dice."""

    def __init__(self, roll: Roll):
        self.roll = roll

    def _eval(self, vtable):
        """Return the sum of the values produced by the wrapped roll."""
        dice = self.roll.eval(vtable)
        return sum(dice)

    def _show(self, vtable):
        """Render the roll's display strings comma-separated inside ``[...]``."""
        shown = self.roll.show(vtable)
        joined = ",".join(shown)
        return f"[{joined}]"

    def includes(self, _):
        """Always ``False``; the wrapped roll is not searched for variables."""
        return False

    def __repr__(self) -> str:
        return f"sum({self.roll})"

    def __eq__(self, other: Exp) -> bool:
        if not isinstance(other, ExpRoll):
            return False
        return self.roll == other.roll
class Roll(Exp):
    """Basic dice roll: ``exp1`` dice, each with ``exp2`` sides.

    The outcome is stored in ``result`` after the first successful
    evaluation, so later evaluations of the same node return the same dice.
    ``show_list`` holds one display string per die.
    """

    def __init__(self, exp1: Exp, exp2: Exp):
        self.exp1 = exp1
        self.exp2 = exp2
        self.result = None
        # Initialised here so that _show never hits a missing attribute when
        # evaluation fails before any dice are rolled.
        self.show_list = None

    def _eval(self, vtable):
        """Roll the dice (once) and return the list of individual results.

        Returns ``[]`` without caching when the number of dice or the number
        of sides does not evaluate to an int. ``randint`` raises
        ``ValueError`` if at least one die is rolled with fewer than one side.
        """
        if self.result is not None:
            return self.result
        count = self.exp1.eval(vtable)
        sides = self.exp2.eval(vtable)
        if not (isinstance(count, int) and isinstance(sides, int)):
            return []
        self.die_type = sides
        self.result = [randint(1, sides) for _ in range(count)]
        self.show_list = [str(value) for value in self.result]
        return self.result

    def _show(self, vtable):
        """Return the display strings of the dice, evaluating first if needed.

        Returns ``[]`` when evaluation produced no display list.
        """
        self.eval(vtable)
        if self.show_list is None:
            return []
        return self.show_list

    @property
    def die(self) -> int:
        """Number of sides of the underlying die, or ``0`` if unknown.

        Nodes that wrap another roll in a ``roll`` attribute delegate to that
        roll's ``die`` property, so the lookup works through any depth of
        nesting and before the inner roll has been evaluated.
        """
        if hasattr(self, "die_type"):
            return self.die_type
        if hasattr(self, "roll"):
            return self.roll.die
        return 0

    def __repr__(self) -> str:
        return f"roll({self.exp1},{self.exp2})"

    def __eq__(self, other: Exp) -> bool:
        return (
            isinstance(other, Roll) and
            self.exp1 == other.exp1 and
            self.exp2 == other.exp2
        )
class RollKeepHighest(Roll):
    """Roll modifier that keeps only the highest ``exp`` dice of ``roll``."""

    def __init__(self, roll: Roll, exp: Exp):
        self.roll = roll
        self.exp = exp
        self.result = None
        self.show_list = None

    def _eval(self, vtable):
        """Return the kept dice; dropped dice are struck through in ``show_list``.

        Returns ``[]`` without caching when the inner roll is not a list of
        ints or the keep count is not an int.
        """
        if self.result is not None:
            return self.result
        dice = self.roll.eval(vtable)
        keep = self.exp.eval(vtable)
        dice_ok = isinstance(dice, list) and all(isinstance(d, int) for d in dice)
        if not (dice_ok and isinstance(keep, int)):
            return []
        kept = []
        for index, value in enumerate(dice):
            if len(kept) < keep:
                kept.append(index)
                continue
            # Replace the lowest kept die when the new die beats any kept one.
            if any(value > dice[j] for j in kept):
                lowest = min(kept, key=lambda j: dice[j])
                kept.remove(lowest)
                kept.append(index)
        self.result = [dice[j] for j in kept]
        shown = []
        for index, value in enumerate(dice):
            shown.append(str(value) if index in kept else f"~~{value}~~")
        self.show_list = shown
        return self.result

    def __repr__(self) -> str:
        return f"kep_highest({self.roll},{self.exp})"

    def __eq__(self, other: Exp) -> bool:
        if not isinstance(other, RollKeepHighest):
            return False
        return self.roll == other.roll and self.exp == other.exp
class RollKeepLowest(Roll):
    """Roll modifier that keeps only the lowest ``exp`` dice of ``roll``."""

    def __init__(self, roll: Roll, exp: Exp):
        self.roll = roll
        self.exp = exp
        self.result = None
        self.show_list = None

    def _eval(self, vtable):
        """Return the kept dice; dropped dice are struck through in ``show_list``.

        Returns ``[]`` without caching when the inner roll is not a list of
        ints or the keep count is not an int.
        """
        if self.result is not None:
            return self.result
        dice = self.roll.eval(vtable)
        keep = self.exp.eval(vtable)
        dice_ok = isinstance(dice, list) and all(isinstance(d, int) for d in dice)
        if not (dice_ok and isinstance(keep, int)):
            return []
        kept = []
        for index, value in enumerate(dice):
            if len(kept) < keep:
                kept.append(index)
                continue
            # Replace the highest kept die when the new die is below any kept one.
            if any(value < dice[j] for j in kept):
                highest = max(kept, key=lambda j: dice[j])
                kept.remove(highest)
                kept.append(index)
        self.result = [dice[j] for j in kept]
        shown = []
        for index, value in enumerate(dice):
            shown.append(str(value) if index in kept else f"~~{value}~~")
        self.show_list = shown
        return self.result

    def __repr__(self) -> str:
        return f"kep_lowest({self.roll},{self.exp})"

    def __eq__(self, other: Exp) -> bool:
        if not isinstance(other, RollKeepLowest):
            return False
        return self.roll == other.roll and self.exp == other.exp
class RollMin(Roll):
    """Roll modifier that raises every die below ``exp`` up to ``exp``."""

    def __init__(self, roll: Roll, exp: Exp):
        self.roll = roll
        self.exp = exp
        self.result = None
        self.show_list = None

    def _eval(self, vtable):
        """Return the dice of ``roll`` with ``exp`` applied as a lower bound.

        Raised dice are displayed as ``<bound>^``. Returns ``[]`` without
        caching when the inner roll is not a list of ints or the bound is
        not an int.
        """
        if self.result is not None:
            return self.result
        dice = self.roll.eval(vtable)
        bound = self.exp.eval(vtable)
        dice_ok = isinstance(dice, list) and all(isinstance(d, int) for d in dice)
        if not (dice_ok and isinstance(bound, int)):
            return []
        # Build new lists: ``dice`` is the list object returned by the inner
        # roll (its cached result), so it must not be modified in place.
        clamped = []
        shown = []
        for value in dice:
            if value < bound:
                clamped.append(bound)
                shown.append(f"{bound}^")
            else:
                clamped.append(value)
                shown.append(str(value))
        self.show_list = shown
        self.result = clamped
        return self.result

    def __repr__(self) -> str:
        return f"min({self.roll},{self.exp})"

    def __eq__(self, other: Exp) -> bool:
        return (
            isinstance(other, RollMin) and
            self.roll == other.roll and
            self.exp == other.exp
        )
class RollMax(Roll):
    """Roll modifier that lowers every die above ``exp`` down to ``exp``."""

    def __init__(self, roll: Roll, exp: Exp):
        self.roll = roll
        self.exp = exp
        self.result = None
        self.show_list = None

    def _eval(self, vtable):
        """Return the dice of ``roll`` with ``exp`` applied as an upper bound.

        Lowered dice are displayed as ``<bound>v``. Returns ``[]`` without
        caching when the inner roll is not a list of ints or the bound is
        not an int.
        """
        if self.result is not None:
            return self.result
        dice = self.roll.eval(vtable)
        bound = self.exp.eval(vtable)
        dice_ok = isinstance(dice, list) and all(isinstance(d, int) for d in dice)
        if not (dice_ok and isinstance(bound, int)):
            return []
        # Build new lists: ``dice`` is the list object returned by the inner
        # roll (its cached result), so it must not be modified in place.
        clamped = []
        shown = []
        for value in dice:
            if value > bound:
                clamped.append(bound)
                shown.append(f"{bound}v")
            else:
                clamped.append(value)
                shown.append(str(value))
        self.show_list = shown
        self.result = clamped
        return self.result

    def __repr__(self) -> str:
        return f"max({self.roll},{self.exp})"

    def __eq__(self, other: Exp) -> bool:
        return (
            isinstance(other, RollMax) and
            self.roll == other.roll and
            self.exp == other.exp
        )
class RollExplode(Roll):
    """Roll modifier that rolls an extra die for every die matching ``comp``.

    When no comparison is given, a die explodes when it equals the die size
    reported by the ``die`` property.
    """

    def __init__(self, roll: Roll, comp: ComparePoint = None):
        self.roll = roll
        self.comp = comp
        self.result = None
        self.show_list = None

    def _eval(self, vtable):
        """Return the dice of ``roll`` plus all dice added by explosions.

        Exploding dice are displayed as ``<n>!``. Returns ``[]`` without
        caching when the inner roll is not a list of ints.
        """
        if self.result is not None:
            return self.result
        r1 = self.roll.eval(vtable)
        if not (isinstance(r1, list) and all(isinstance(i, int) for i in r1)):
            return []
        d = self.die
        if self.comp is None:
            # Default comparison, stored on the node: explode on the die's
            # highest face.
            self.comp = ComparePoint("=", ExpInt(d))
        self.result = []
        self.show_list = []

        def compare(n):
            # NOTE: extra dice can explode again. A comparison that matches
            # every face never stops by itself and ends in RecursionError.
            if self.comp.eval(vtable, n):
                self.result.append(n)
                self.show_list.append(f"{n}!")
                compare(randint(1, d))
            else:
                self.result.append(n)
                self.show_list.append(str(n))

        for n in r1:
            compare(n)
        return self.result

    def __repr__(self) -> str:
        # This node has ``comp`` and no ``exp`` attribute.
        return f"explode({self.roll},{self.comp})"

    def __eq__(self, other: Exp) -> bool:
        return (
            isinstance(other, RollExplode) and
            self.roll == other.roll and
            self.comp == other.comp
        )
class RollReroll(Roll):
    """Roll modifier that rerolls every die matching ``comp``.

    When no comparison is given, dice equal to 1 are rerolled. With ``once``
    set, a die that has already been rerolled is kept whatever it shows.
    """

    def __init__(self, roll: Roll, comp: ComparePoint = None, once: bool = False):
        self.roll = roll
        self.comp = comp
        self.once = once
        self.result = None
        self.show_list = None

    def _eval(self, vtable):
        """Return the dice of ``roll`` after rerolling.

        Discarded dice are displayed struck through (``~~n~~``) and are not
        part of the result. Returns ``[]`` without caching when the inner
        roll is not a list of ints.
        """
        if self.result is not None:
            return self.result
        r1 = self.roll.eval(vtable)
        if not (isinstance(r1, list) and all(isinstance(i, int) for i in r1)):
            return []
        d = self.die
        if self.comp is None:
            # Default comparison, stored on the node: reroll ones.
            self.comp = ComparePoint("=", ExpInt(1))
        self.result = []
        self.show_list = []

        def compare(n, rerolled):
            # NOTE: without ``once``, a comparison that matches every face
            # never stops by itself and ends in RecursionError.
            if self.comp.eval(vtable, n) and not (rerolled and self.once):
                self.show_list.append(f"~~{n}~~")
                compare(randint(1, d), True)
            else:
                self.result.append(n)
                self.show_list.append(str(n))

        for n in r1:
            compare(n, False)
        return self.result

    def __repr__(self) -> str:
        # This node has ``comp`` and ``once`` and no ``exp`` attribute.
        return f"reroll({self.roll},{self.comp},{self.once})"

    def __eq__(self, other: Exp) -> bool:
        return (
            isinstance(other, RollReroll) and
            self.roll == other.roll and
            self.comp == other.comp and
            self.once == other.once
        )