from __future__ import annotations
import functools
from math import gcd
from itertools import product
from tinygrad.helpers import partition
from typing import (
List,
Dict,
Callable,
Tuple,
Type,
Union,
Optional,
Any,
Iterator,
Set,
)
# NOTE: Python has different behavior for negative mod and floor div than C.
# symbolic matches the Python behavior, but the generated code is agnostic to the difference because it never has negative numbers in div or mod.
def is_sym_int(x: Any) -> bool:
    """
    Check whether `x` is a symbolic integer, i.e. a plain `int` or a `Node`.

    Args:
        x (Any): The object to check.

    Returns:
        bool: True if `x` is an `int` (including `bool`, which subclasses `int`)
        or a `Node` instance, False otherwise.
    """
    return isinstance(x, (int, Node))
class Node:
    """
    Base class for a node in a symbolic integer expression tree.

    Every node tracks an inclusive integer range ``[min, max]`` of the values it can
    take. The arithmetic operators build new (simplified) nodes instead of
    evaluating anything.

    Attributes:
        b (Union[Node, int]): Operand/payload of the node; its meaning depends on the subclass.
        min (int): Smallest value this node can take.
        max (int): Largest value this node can take.
    """

    b: Union[Node, int]
    min: int
    max: int

    def render(self, ops=None, ctx=None) -> Any:
        """
        Render the expression tree using a table of per-node-type render functions.

        Args:
            ops (Optional[Dict[Type[Node], Callable]]): Maps each node type to the
                function that renders it. Defaults to ``render_python``.
            ctx (Optional[Any]): Context passed through to the render functions
                (this module itself uses "DEBUG" and "REPR").

        Returns:
            Any: Whatever the render function registered for ``type(self)`` returns.

        Raises:
            AssertionError: If a node other than Variable/NumNode has min == max.
        """
        if ops is None:
            ops = render_python
        # only leaves may be constant; other constant nodes should have been folded by create_node
        assert self.__class__ in (Variable, NumNode) or self.min != self.max
        return ops[type(self)](self, ops, ctx)

    def vars(self) -> Set[Variable]:
        """Return the set of Variables in this node (empty for the base class)."""
        return set()

    def expand_idx(self) -> VariableOrNum:
        """
        Pick the default index to expand over.

        Returns:
            VariableOrNum: A Variable from ``self.vars()`` whose ``expr`` is None, or
            ``NumNode(0)`` if there is none. ``vars()`` is a set, so which Variable is
            returned is unspecified when several qualify.
        """
        return next((v for v in self.vars() if v.expr is None), NumNode(0))

    # NOTE: lru_cache keys on self, so every expanded node stays referenced by the cache
    @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
    def expand(self, idxs: Optional[Tuple[VariableOrNum, ...]] = None) -> List[Node]:
        """
        Expand this node into one node per assignment of the index Variables.

        Each index runs from its ``min`` to its ``max`` (inclusive). Earlier entries of
        ``idxs`` vary fastest (see ``iter_idxs``).

        Args:
            idxs (Optional[Tuple[VariableOrNum, ...]]): Indices to enumerate. Must be
                hashable (a tuple) because results are memoized. Defaults to
                ``(self.expand_idx(),)``.

        Returns:
            List[Node]: ``self`` with the indices substituted by concrete NumNodes,
            in enumeration order.
        """
        if idxs is None:
            idxs = (self.expand_idx(),)
        return [
            self.substitute(dict(zip(idxs, (NumNode(x) for x in rep))))
            for rep in Node.iter_idxs(idxs)
        ]

    @staticmethod
    def iter_idxs(idxs: Tuple[VariableOrNum, ...]) -> Iterator[Tuple[int, ...]]:
        """
        Enumerate every combination of index values from ``min`` to ``max`` (inclusive).

        ``itertools.product`` varies its last argument fastest. The ranges are fed in
        reverse and each tuple is flipped back, so the FIRST index varies fastest.

        Args:
            idxs (Tuple[VariableOrNum, ...]): Objects exposing integer ``min``/``max``.

        Yields:
            Tuple[int, ...]: One value per entry of ``idxs``, in the order of ``idxs``.
        """
        ranges = [range(v.min, v.max + 1) for v in reversed(idxs)]
        for combo in product(*ranges):
            yield combo[::-1]

    def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
        """
        Replace Variables with the nodes given in ``var_vals``.

        Subclasses override this; the base implementation always raises.

        Raises:
            RuntimeError: Always, carrying the name of the class lacking an implementation.
        """
        raise RuntimeError(self.__class__.__name__)

    def unbind(self) -> Tuple[Node, Optional[int]]:
        """
        Replace every bound Variable in this node with an unbound copy.

        Returns:
            Tuple[Node, Optional[int]]: The substituted node and ``None``.
        """
        # Read _val directly. The `val` property asserts that the Variable is bound,
        # so testing `v.val is not None` would raise on the first unbound Variable.
        return (
            self.substitute({v: v.unbind()[0] for v in self.vars() if v._val is not None}),
            None,
        )

    @functools.cached_property
    def key(self) -> str:
        """Canonical string form (the "DEBUG" rendering), used for equality and hashing."""
        return self.render(ctx="DEBUG")

    @functools.cached_property
    def hash(self) -> int:
        """Hash of ``key``, computed once and cached."""
        return hash(self.key)

    def __repr__(self):
        """Return the node rendered with the "REPR" context."""
        return self.render(ctx="REPR")

    def __str__(self):
        """Return ``key`` wrapped in angle brackets."""
        return "<" + self.key + ">"

    def __hash__(self):
        """Return the cached hash of ``key``, consistent with ``__eq__``."""
        return self.hash

    def __bool__(self):
        """A node is falsy only when it is known to be exactly 0 (min == max == 0)."""
        return not (self.max == self.min == 0)

    def __eq__(self, other: object) -> bool:
        """
        Structural equality: two nodes are equal when their ``key`` strings match.

        Non-Node operands return NotImplemented so Python can try the reflected
        comparison, falling back to identity.
        """
        if not isinstance(other, Node):
            return NotImplemented
        return self.key == other.key

    def __neg__(self):
        """Return ``-self``, built as ``self * -1``."""
        return self * -1

    def __add__(self, b: Union[Node, int]):
        """Return ``self + b``; an int ``b`` is wrapped in a NumNode and the pair is summed."""
        return Variable.sum([self, b if isinstance(b, Node) else NumNode(b)])

    def __radd__(self, b: int):
        """Return ``b + self`` (addition is commutative)."""
        return self + b

    def __sub__(self, b: Union[Node, int]):
        """Return ``self - b``, built as ``self + (-b)``."""
        return self + -b

    def __rsub__(self, b: int):
        """Return ``b - self``, built as ``(-self) + b``."""
        return -self + b

    # The comparisons below all reduce to __lt__ using integer identities:
    #   a <= b  <=>  a < b + 1      a > b  <=>  -a < -b      a >= b  <=>  -a < -b + 1
    def __le__(self, b: Union[Node, int]):
        """Return a node for ``self <= b``, rewritten as ``self < b + 1``."""
        return self < (b + 1)

    def __gt__(self, b: Union[Node, int]):
        """Return a node for ``self > b``, rewritten as ``-self < -b``."""
        return (-self) < (-b)

    def __ge__(self, b: Union[Node, int]):
        """Return a node for ``self >= b``, rewritten as ``-self < -b + 1``."""
        return (-self) < (-b + 1)

    def __lt__(self, b: Union[Node, int]):
        """Return a node for ``self < b``: an LtNode, or a NumNode 0/1 when the answer is already known."""
        return create_node(LtNode(self, b))

    def __mul__(self, b: Union[Node, int]):
        """
        Return ``self * b``, folding trivial cases.

        ``b == 0`` gives ``NumNode(0)``, ``b == 1`` gives ``self``, constants are folded,
        and a NumNode ``b`` is unwrapped to its int before a MulNode is built.
        """
        if b == 0:
            return NumNode(0)
        if b == 1:
            return self
        if self.__class__ is NumNode:
            # constant * int folds; constant * Node is delegated to the other node's __mul__
            return NumNode(self.b * b) if isinstance(b, int) else b * self.b
        if isinstance(b, NumNode):
            return create_node(MulNode(self, b.b))
        return create_node(MulNode(self, b))

    def __rmul__(self, b: int):
        """Return ``b * self`` (multiplication is commutative)."""
        return self * b

    # *** complex ops ***

    def __rfloordiv__(self, b: int):
        """
        Return ``b // self`` for an int ``b``.

        Raises:
            RuntimeError: If the quotient cannot be determined symbolically.
        """
        if self.min > b >= 0:
            # 0 <= b < self for every value of self, so the quotient is always 0
            return NumNode(0)
        if isinstance(self, NumNode):
            return NumNode(b // self.b)
        raise RuntimeError(f"not supported: {b} // {self}")

    def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
        """
        Return ``self // b``.

        A negative int divisor is handled as ``-(self // -b)``. That is ceiling
        division rather than Python's floor division when the division is inexact.
        NOTE(review): this relies on div never seeing negative divisors in practice.

        Args:
            b (Union[Node, int]): The divisor. A Node divisor is only supported when
                the quotient is trivially known.
            factoring_allowed (bool): Not used by this base implementation. It exists
                so that overrides (SumNode) can be told not to factor again.

        Returns:
            Node: The quotient node.

        Raises:
            RuntimeError: For a Node divisor whose quotient cannot be determined.
            AssertionError: If ``b == 0``.
        """
        if isinstance(b, Node):
            if b.__class__ is NumNode:
                return self // b.b
            if self == b:
                return NumNode(1)
            if (b - self).min > 0 and self.min >= 0:
                return NumNode(0)  # b - self simplifies the node
            raise RuntimeError(f"not supported: {self} // {b}")
        assert b != 0
        if b < 0:
            return (self // -b) * -1
        if b == 1:
            return self
        # the numerator of a DivNode is not allowed to be negative
        if self.min < 0:
            offset = self.min // b
            # factor out an "offset" to make the numerator non-negative; don't allow factoring again
            return (self + -offset * b).__floordiv__(b, factoring_allowed=False) + offset
        return create_node(DivNode(self, b))

    def __rmod__(self, b: int):
        """
        Return ``b % self`` for an int ``b``.

        Raises:
            RuntimeError: If the remainder cannot be determined symbolically.
        """
        if self.min > b >= 0:
            # 0 <= b < self for every value of self, so the remainder is b itself
            return NumNode(b)
        if isinstance(self, NumNode):
            return NumNode(b % self.b)
        raise RuntimeError(f"not supported: {b} % {self}")

    def __mod__(self, b: Union[Node, int]):
        """
        Return ``self % b``, simplifying when the range of ``self`` allows it.

        Raises:
            RuntimeError: For a Node divisor whose remainder cannot be determined.
            AssertionError: If an int ``b`` is not positive.
        """
        if isinstance(b, Node):
            if b.__class__ is NumNode:
                return self % b.b
            if self == b:
                return NumNode(0)
            if (b - self).min > 0 and self.min >= 0:
                return self  # b - self simplifies the node
            raise RuntimeError(f"not supported: {self} % {b}")
        assert b > 0
        if b == 1:
            return NumNode(0)
        if isinstance(self.max, int) and isinstance(self.min, int):
            if self.min >= 0 and self.max < b:
                # already inside [0, b): the mod is a no-op
                return self
            if (self.min // b) == (self.max // b):
                # the whole range lies within one period of b: subtract that period's base
                return self - (b * (self.min // b))
            if self.min < 0:
                # shift by a multiple of b so the operand becomes non-negative, then retry
                return (self - ((self.min // b) * b)) % b
        return create_node(ModNode(self, b))

    @staticmethod
    def sum(nodes: List[Node]) -> Node:
        """
        Build a simplified sum of ``nodes``.

        Terms that are exactly 0 are dropped. The terms are then taken from
        ``SumNode(nodes).flat_components``, presumably the flattened terms
        (defined outside this view). Constant terms are folded into one NumNode.
        Repeated terms are merged into a MulNode with the summed coefficient, and
        terms whose coefficients cancel to 0 are dropped.

        Args:
            nodes (List[Node]): The terms to add.

        Returns:
            Node: ``NumNode(0)`` for an empty sum, the single term when only one
            remains, otherwise ``create_rednode(SumNode, ...)`` of the merged terms.
        """
        nodes = [x for x in nodes if x.max or x.min]  # drop terms that are exactly 0
        if not nodes:
            return NumNode(0)
        if len(nodes) == 1:
            return nodes[0]

        # coefficient per distinct term, plus a running constant
        mul_groups: Dict[Node, int] = {}
        num_node_sum = 0
        for node in SumNode(nodes).flat_components:
            if node.__class__ is NumNode:
                num_node_sum += node.b
            elif node.__class__ is MulNode:
                mul_groups[node.a] = mul_groups.get(node.a, 0) + node.b
            else:
                mul_groups[node] = mul_groups.get(node, 0) + 1

        new_nodes = [
            MulNode(a, b_sum) if b_sum != 1 else a
            for a, b_sum in mul_groups.items()
            if b_sum != 0
        ]
        if num_node_sum:
            new_nodes.append(NumNode(num_node_sum))

        if len(new_nodes) > 1:
            return create_rednode(SumNode, new_nodes)
        if len(new_nodes) == 1:
            return new_nodes[0]
        return NumNode(0)

    @staticmethod
    def ands(nodes: List[Node]) -> Node:
        """
        Build the logical AND of ``nodes``.

        Returns:
            Node: ``NumNode(1)`` for an empty list; the node itself for a single node;
            ``NumNode(0)`` if any node is known to be 0. Otherwise constant nodes are
            dropped and the rest are combined with ``create_rednode(AndNode, ...)``.
            If only one non-constant node remains it is returned; if none remain the
            result is ``NumNode(1)``.
        """
        if not nodes:
            return NumNode(1)
        if len(nodes) == 1:
            return nodes[0]
        if any(not x for x in nodes):
            return NumNode(0)
        # any constant left here is non-zero, so it cannot affect the AND; drop it
        nodes = [x for x in nodes if x.min != x.max]
        if len(nodes) > 1:
            return create_rednode(AndNode, nodes)
        return nodes[0] if len(nodes) == 1 else NumNode(1)
# 4 basic node types
class Variable(Node):
    """
    A named integer variable with an inclusive range ``[min, max]``.

    A Variable can optionally be bound to one concrete value with ``bind``.

    Attributes:
        expr (Optional[str]): Name of the variable (may be None).
        min (int): Smallest value the variable can take.
        max (int): Largest value the variable can take.
    """

    def __new__(cls, expr: Optional[str], nmin: int, nmax: int):
        """
        Create a Variable, or a NumNode when the range holds a single value.

        Raises:
            AssertionError: If ``nmin < 0`` or ``nmin > nmax``.
        """
        assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
        if nmin == nmax:
            # a one-value range is a constant; NumNode is not a Variable, so __init__ is skipped
            return NumNode(nmin)
        return super().__new__(cls)

    def __init__(self, expr: Optional[str], nmin: int, nmax: int):
        """Store the name and range; the Variable starts out unbound."""
        self.expr, self.min, self.max = expr, nmin, nmax
        self._val: Optional[int] = None

    @property
    def val(self):
        """
        The value this Variable is bound to.

        Raises:
            AssertionError: If the Variable is not bound.
        """
        assert self._val is not None, f"Variable isn't bound, can't access val of {self}"
        return self._val

    def bind(self, val):
        """
        Bind this Variable to ``val`` in place and return it.

        Raises:
            AssertionError: If already bound, or if ``val`` is outside ``[min, max]``.
        """
        assert self._val is None and self.min <= val <= self.max, f"cannot bind {val} to {self}"
        self._val = val
        return self

    def unbind(self) -> Tuple[Variable, int]:
        """
        Return a new unbound Variable with the same name and range, plus the bound value.

        Raises:
            AssertionError: With "cannot unbind ..." if this Variable is not bound.
        """
        # Check _val directly so the intended message is raised. Going through the
        # `val` property would fail first with its own, unrelated message.
        assert self._val is not None, f"cannot unbind {self}"
        return Variable(self.expr, self.min, self.max), self._val

    def vars(self):
        """Return ``{self}``."""
        return {self}

    def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
        """Return ``var_vals[self]`` if this Variable is a key, otherwise ``self``."""
        return var_vals[self] if self in var_vals else self
class NumNode(Node):
    """
    A constant integer node.

    Attributes:
        b (int): The constant value.
        min (int): Equal to ``b``.
        max (int): Equal to ``b``.
    """

    def __init__(self, num: int):
        """
        Wrap the integer ``num``.

        Raises:
            AssertionError: If ``num`` is not an int.
        """
        assert isinstance(num, int), f"{num} is not an int"
        self.b: int = num
        self.min, self.max = num, num

    def bind(self, val):
        """
        Accept a binding only if it equals the constant, and return ``self``.

        Raises:
            AssertionError: If ``val`` differs from ``self.b``.
        """
        assert self.b == val, f"cannot bind {val} to {self}"
        return self

    def __eq__(self, other):
        """Compare the constant with ``other`` (an int, or another node via its own ``__eq__``)."""
        return self.b == other

    def __hash__(self):
        """Hash via the cached ``key``; must be redefined because ``__eq__`` is overridden."""
        return self.hash  # needed with __eq__ override

    def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
        """Constants contain no Variables, so return ``self`` unchanged."""
        return self
def create_node(ret: Node):
    """
    Finalize a freshly built node, collapsing it to a constant when its range is a single value.

    Args:
        ret (Node): The candidate node, with ``min``/``max`` already computed.

    Returns:
        Node: ``NumNode(ret.min)`` if ``ret.min == ret.max``, otherwise ``ret`` itself.

    Raises:
        AssertionError: If ``ret.min > ret.max``.
    """
    assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
    if ret.min == ret.max:
        return NumNode(ret.min)
    return ret
class OpNode(Node):
    """
    Base class for binary operation nodes ``a <op> b``.

    Subclasses implement ``get_bounds``, which is called once at construction to
    set ``min``/``max``.

    Attributes:
        a (Node): The left operand.
        b (Union[Node, int]): The right operand, either a node or a plain int.
    """

    def __init__(self, a: Node, b: Union[Node, int]):
        """Store the operands and compute ``min``/``max`` via ``get_bounds``."""
        self.a, self.b = a, b
        self.min, self.max = self.get_bounds()

    def vars(self):
        """Return the Variables of ``a`` together with those of ``b`` (when ``b`` is a node)."""
        return self.a.vars() | (self.b.vars() if isinstance(self.b, Node) else set())

    def get_bounds(self) -> Tuple[int, int]:
        """
        Return the ``(min, max)`` range of this operation.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError("must be implemented")
class LtNode(OpNode):
    """
    A comparison node for ``a < b``; it evaluates to 0 (false) or 1 (true).

    Attributes:
        a (Node): Left-hand side.
        b (Union[Node, int]): Right-hand side.
    """

    def __floordiv__(self, b: Union[Node, int], _=False):
        """
        Rewrite the comparison with both sides floor-divided by ``b``.

        Returns ``(a // b) < (rhs // b)``. This rewrites the comparison; it does not
        divide the 0/1 result. The second parameter mirrors ``factoring_allowed`` on
        other nodes and is ignored.

        Returns:
            Node: The rewritten comparison node.
        """
        return (self.a // b) < (self.b // b)

    def get_bounds(self) -> Tuple[int, int]:
        """Return ``(1, 1)`` if ``a < b`` always holds, ``(0, 0)`` if it never holds, else ``(0, 1)``."""
        if isinstance(self.b, int):
            rhs_min, rhs_max = self.b, self.b
        else:
            rhs_min, rhs_max = self.b.min, self.b.max
        if self.a.max < rhs_min:
            return (1, 1)
        if self.a.min >= rhs_max:
            return (0, 0)
        return (0, 1)

    def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
        """Substitute Variables on both sides and rebuild the comparison with ``<``."""
        lhs = self.a.substitute(var_vals)
        rhs = self.b if isinstance(self.b, int) else self.b.substitute(var_vals)
        return lhs < rhs
class MulNode(OpNode):
    """
    A product node ``a * b``.

    Attributes:
        a (Node): The node being scaled.
        b (Union[Node, int]): The factor; most simplifications below require an int.
    """

    def __lt__(self, b: Union[Node, int]):
        """
        Return a node for ``a * k < b``, dividing the int factor ``k`` out when possible.

        When both ``b`` and ``k`` are ints and ``k != -1``, the comparison becomes
        ``a * sign(k) < ceil(b / |k|)``. Otherwise a plain LtNode is built.
        """
        if isinstance(b, Node) or isinstance(self.b, Node) or self.b == -1:
            return Node.__lt__(self, b)
        sgn = 1 if self.b > 0 else -1
        # (b + |k| - 1) // |k| is ceil(b / |k|) for a positive |k|
        return Node.__lt__(self.a * sgn, (b + abs(self.b) - 1) // abs(self.b))

    def __mul__(self, b: Union[Node, int]):
        """Fold a further multiplication into the factor: ``(a * k) * b`` becomes ``a * (k * b)``."""
        return self.a * (self.b * b)  # two muls in one mul

    def __floordiv__(self, b: Union[Node, int], factoring_allowed=False):  # NOTE: mod negative isn't handled right
        """
        Return ``(a * k) // b``, simplifying when one of ``k`` and ``b`` divides the other.

        - If ``b`` divides ``k``, return ``a * (k // b)``.
        - If a positive ``k`` divides ``b``, return ``a // (b // k)``.
        - Otherwise fall back to ``Node.__floordiv__``.
        """
        if self.b % b == 0:
            return self.a * (self.b // b)
        if b % self.b == 0 and self.b > 0:
            return self.a // (b // self.b)
        return Node.__floordiv__(self, b, factoring_allowed)

    def __mod__(self, b: Union[Node, int]):
        """Return ``(a * k) % b``, first reducing the factor to ``k % b``."""
        a = self.a * (self.b % b)
        return Node.__mod__(a, b)

    def get_bounds(self) -> Tuple[int, int]:
        """Scale the range of ``a`` by the factor; a negative factor swaps the endpoints."""
        if self.b >= 0:
            return (self.a.min * self.b, self.a.max * self.b)
        return (self.a.max * self.b, self.a.min * self.b)

    def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
        """Substitute Variables in both operands and rebuild the product with ``*``."""
        lhs = self.a.substitute(var_vals)
        rhs = self.b if isinstance(self.b, int) else self.b.substitute(var_vals)
        return lhs * rhs
class DivNode(OpNode):
    """
    A floor-division node ``a // b`` whose numerator is never negative.

    Attributes:
        a (Node): The dividend; ``a.min`` must be >= 0.
        b (int): The divisor.
    """

    def __floordiv__(self, b: Union[Node, int], _=False):
        """Merge nested divisions: ``(a // k) // b`` becomes ``a // (k * b)``."""
        return self.a // (self.b * b)  # two divs is one div

    def get_bounds(self) -> Tuple[int, int]:
        """
        Return ``(a.min // b, a.max // b)``.

        Raises:
            AssertionError: If ``a`` can be negative or ``b`` is not an int.
        """
        assert self.a.min >= 0 and isinstance(self.b, int)
        return self.a.min // self.b, self.a.max // self.b

    def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
        """Substitute Variables in the dividend and redo the division."""
        return self.a.substitute(var_vals) // self.b
class ModNode(OpNode):
    """
    A modulo node ``a % b`` whose left operand is never negative.

    Attributes:
        a (Node): The dividend; ``a.min`` must be >= 0.
        b (int): The modulus.
    """

    def __mod__(self, b: Union[Node, int]):
        """Return ``(a % m) % b``; when ``b`` divides ``m`` this is simply ``a % b``."""
        if isinstance(b, Node) or isinstance(self.b, Node):
            return Node.__mod__(self, b)
        # gcd(m, b) == b  <=>  b divides m
        return self.a % b if gcd(self.b, b) == b else Node.__mod__(self, b)

    def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
        """
        Return ``(a % m) // b``.

        When ``b`` divides ``m``, the division moves inside the mod:
        ``(a // b) % (m // b)``. Otherwise fall back to ``Node.__floordiv__``.
        """
        if self.b % b == 0:
            return (self.a // b) % (self.b // b)  # put the div inside mod
        return Node.__floordiv__(self, b, factoring_allowed)

    def get_bounds(self) -> Tuple[int, int]:
        """
        Return the range of ``a % b``.

        The full range ``(0, b - 1)`` is used when ``a`` spans at least ``b`` values or
        its range wraps past a multiple of ``b``. Otherwise the range is
        ``(a.min % b, a.max % b)``.

        Raises:
            AssertionError: If ``a`` can be negative or ``b`` is not an int.
        """
        assert self.a.min >= 0 and isinstance(self.b, int)
        # short-circuit keeps the `% self.b` terms unevaluated when the span check already decides
        if self.a.max - self.a.min >= self.b or (
            self.a.min != self.a.max and self.a.min % self.b >= self.a.max % self.b
        ):
            return (0, self.b - 1)
        return (self.a.min % self.b, self.a.max % self.b)

    def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
        """Substitute Variables in the dividend and redo the modulo."""
        return self.a.substitute(var_vals) % self.b
class RedNode(Node):
    """
    Base class for reduction nodes that combine a list of child nodes (e.g. SumNode).

    Attributes:
        nodes (List[Node]): The child nodes.
    """

    def __init__(self, nodes: List[Node]):
        """
        Store the child nodes.

        NOTE(review): ``min``/``max`` are not set here. Presumably ``create_rednode``
        (defined elsewhere) assigns them; confirm before relying on bounds.
        """
        self.nodes = nodes

    def vars(self) -> Set[Variable]:
        """Return the union of the Variables of all child nodes (empty when there are none)."""
        return set.union(*[x.vars() for x in self.nodes], set())
[docs]
class SumNode(RedNode):
"""
SumNode class for representing sum of nodes.
Attributes:
__mul__ (method): Distribute multiplication over summation.
__floordiv__ (method): Perform floor division with consideration to factoring.
"""
[docs]
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def __mul__(self, b: Union[Node, int]):
"""
Multiply a SumNode with a Node or an integer.
Args:
self (SumNode): The SumNode object.
b (Union[Node, int]): A Node or an integer to multiply the SumNode by.
Returns:
Node: Result of multiplication operation.
"""
return Node.sum([x * b for x in self.nodes]) # distribute mul into sum
[docs]
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
"""
Perform floor division on a SumNode.
The operation considers factoring and fully dividing nodes before performing the operation.
Args:
self (SumNode): The SumNode object.
b (Union[Node, int]): A Node or an integer to divide the SumNode by.
factoring_allowed (bool, optional): Whether factoring is allowed in the division. Defaults to True.
Returns:
Node: Result of floor division operation.
"""
fully_divided: List[Node] = []
rest: List[Node] = []
if isinstance(b, SumNode):
nu_num = sum(
node.b for node in self.flat_components if node.__class__ is NumNode
)
de_num = sum(
node.b for node in b.flat_components if node.__class__ is NumNode
)
if nu_num > 0 and de_num and (d := nu_num // de_num) > 0:
return NumNode(d) + (self - b * d) // b
if isinstance(b, Node):
for x in self.flat_components:
if x % b == 0:
fully_divided.append(x // b)
else:
rest.append(x)
if (sum_fully_divided := create_rednode(SumNode, fully_divided)) != 0:
return sum_fully_divided + create_rednode(SumNode, rest) // b
return Node.__floordiv__(self, b, False)
if b == 1:
return self
if not factoring_allowed:
return Node.__floordiv__(self, b, factoring_allowed)
fully_divided, rest = [], []
_gcd = b
divisor = 1
for x in self.flat_components:
if x.__class__ in (NumNode, MulNode):
if x.b % b == 0:
fully_divided.append(x // b)
else:
rest.append(x)
if isinstance(x.b, int):
_gcd = gcd(_gcd, x.b)
if x.__class__ == MulNode and divisor == 1 and b % x.b == 0:
divisor = x.b
else:
_gcd = 1
else:
rest.append(x)
_gcd = 1
if _gcd > 1:
return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(_gcd) // (
b // _gcd
)
if divisor > 1:
return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (
b // divisor
)
return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def __mod__(self, b: Union[Node, int]):
"""
Perform modulo operation with the given node or integer.
This method first checks if `b` is an instance of `SumNode`. If so, it calculates the sum of `node.b` for all `NumNode` instances in `self.flat_components` and `b.flat_components`, and performs a modulo operation with the calculated values.
Next, if `b` is an instance of `Node` and `(b - self).min > 0`, it returns `self`. Otherwise, it creates a new list of nodes by iterating over the nodes in `self.nodes`. For each node:
- If it's a `NumNode`, it appends a new `NumNode` with the modulo operation result to `new_nodes`.
- If it's a `MulNode`, it appends the product of `x.a` and `(x.b % b)` to `new_nodes`.
- Otherwise, it appends the node itself to `new_nodes`.
Finally, it returns the result of the modulo operation between the sum of all nodes in `new_nodes` and `b`.
:param b: The node or integer to perform the modulo operation with.
:type b: Union[Node, int]
:return: The result of the modulo operation.
"""
if isinstance(b, SumNode):
nu_num = sum(
node.b for node in self.flat_components if node.__class__ is NumNode
)
de_num = sum(
node.b for node in b.flat_components if node.__class__ is NumNode
)
if nu_num > 0 and de_num and (d := nu_num // de_num) > 0:
return (self - b * d) % b
if isinstance(b, Node) and (b - self).min > 0:
return self # b - self simplifies the node
new_nodes: List[Node] = []
for x in self.nodes:
if x.__class__ is NumNode:
new_nodes.append(NumNode(x.b % b))
elif isinstance(x, MulNode):
new_nodes.append(x.a * (x.b % b))
else:
new_nodes.append(x)
return Node.__mod__(Node.sum(new_nodes), b)
def __lt__(self, b: Union[Node, int]):
    """
    Build the comparison ``self < b``, simplifying it first when `b` is an int.

    When `b` is an int:

    1. Every NumNode term of this sum is moved to the right-hand side
       (``b -= x.b``). The remaining terms form the new left-hand side.
    2. The left-hand side's terms are partitioned into ``muls`` and ``others``.
       ``muls`` holds the MulNodes whose coefficient ``x.b`` is positive and
       whose ``x.max`` is at least the adjusted `b`; everything else goes to
       ``others``.
    3. If ``muls`` is non-empty, ``g`` is the gcd of the adjusted `b` and every
       ``muls`` coefficient. If the sum of ``others`` is known to lie in
       ``[0, g)``, the comparison is rescaled:
       - the left-hand side becomes ``sum(mul // g for mul in muls)``;
       - the bound becomes ``b // g``.

    The (possibly simplified) operands are then passed to ``Node.__lt__``.
    This method never returns a boolean on its own.

    :param b: The node or integer to compare with.
    :type b: Union[Node, int]
    :return: Whatever ``Node.__lt__`` returns for the final operands.
    :raises AssertionError: "not supported" if a MulNode term has a non-int
        coefficient (needed for the gcd computation).
    """
    lhs: Node = self
    if isinstance(b, int):
        new_sum = []
        for x in self.nodes:
            # TODO: should we just force the last one to always be the number
            if isinstance(x, NumNode):
                b -= x.b  # move the constant to the right-hand side
            else:
                new_sum.append(x)
        lhs = Node.sum(new_sum)
        nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
        # gcd() below needs integer MulNode coefficients.
        assert all(
            not isinstance(node, MulNode) or isinstance(node.b, int)
            for node in nodes
        ), "not supported"
        # The lambda reads `b` when it is called, i.e. the bound after the
        # constants above were subtracted.
        muls, others = partition(
            nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b
        )
        if muls:
            # NOTE: gcd in python 3.8 takes exactly 2 args
            mul_gcd = b
            for x in muls:
                mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above
            all_others = Variable.sum(others)
            # mul_gcd divides b and every muls coefficient. A remainder known
            # to be in [0, mul_gcd) therefore cannot change the outcome of the
            # comparison, so both sides can be divided by mul_gcd.
            if all_others.min >= 0 and all_others.max < mul_gcd:
                lhs, b = (
                    Variable.sum([mul // mul_gcd for mul in muls]),
                    b // mul_gcd,
                )
    return Node.__lt__(lhs, b)
[docs]
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
    """
    Apply `var_vals` to every term of this sum and add the results back up.

    Each term's own ``substitute`` is called in order. The substituted terms
    are recombined with ``Variable.sum``, which may simplify the result.

    :param var_vals: Mapping from variables to the nodes that replace them.
    :type var_vals: Dict[VariableOrNum, Node]
    :return: The recombined node after substitution.
    :rtype: Node
    """
    substituted_terms = [term.substitute(var_vals) for term in self.nodes]
    return Variable.sum(substituted_terms)
@property
def flat_components(self):  # recursively expand sumnode components
    """
    Flatten nested sums into a single list of terms.

    Any term that is itself a SumNode is replaced, recursively, by its own
    flattened terms. Every other term is kept as-is, and the original
    left-to-right order is preserved.

    :return: A new list holding the flattened terms.
    :rtype: List[Node]
    """
    flat: List[Node] = []
    for term in self.nodes:
        if isinstance(term, SumNode):
            flat.extend(term.flat_components)
        else:
            flat.append(term)
    return flat
[docs]
class AndNode(RedNode):
    """
    Reduction node whose operands (``self.nodes``) are combined with a logical AND.

    ``create_rednode(AndNode, nodes)`` gives it bounds from its operands:
    ``min`` is the smallest operand ``min`` and ``max`` is the largest operand
    ``max``. Each method below applies an operation to every operand and
    rebuilds the conjunction with ``Variable.ands``.
    """

    def __floordiv__(self, b: Union[Node, int], _=True):
        """
        Floor-divide every operand by `b` and recombine the quotients with AND.

        Parameters:
            b (Union[Node, int]): The divisor applied to each operand.
            _ (bool): Ignored. It presumably mirrors the optional flag of the
                parent ``__floordiv__`` (not visible here), so that callers
                passing a second positional argument keep working.

        Returns:
            Node: ``Variable.ands`` over ``x // b`` for every operand ``x``.
        """
        return Variable.ands([x // b for x in self.nodes])

    def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
        """
        Substitute `var_vals` into every operand and rebuild the conjunction.

        Parameters:
            var_vals (Dict[VariableOrNum, Node]): Mapping from variables to the
                nodes that replace them.

        Returns:
            Node: ``NumNode(0)`` as soon as one substituted operand is falsy.
            Otherwise, ``Variable.ands`` over all substituted operands.
        """
        subed = []
        for node in self.nodes:
            if not (sub := node.substitute(var_vals)):
                # A falsy operand (presumably one that is constantly 0) makes
                # the whole AND false; the remaining operands are not visited.
                return NumNode(0)
            subed.append(sub)
        return Variable.ands(subed)
[docs]
def create_rednode(typ: Type[RedNode], nodes: List[Node]):
    """
    Instantiate ``typ(nodes)``, derive its bounds from the operands and finalize it.

    Bounds are computed for the two reduction kinds handled here:

    * ``SumNode``: ``min`` / ``max`` are the sums of the operands' ``min`` / ``max``.
    * ``AndNode``: ``min`` is the smallest operand ``min`` and ``max`` is the
      largest operand ``max``.

    For any other ``typ`` the bounds are left untouched. The node is always
    passed through ``create_node``, and that result is returned.

    :param typ: The reduction-node class to instantiate.
    :type typ: Type[RedNode]
    :param nodes: The operand nodes.
    :type nodes: List[Node]
    :return: Whatever ``create_node`` returns for the new node.
    """
    node = typ(nodes)
    if typ == SumNode:
        lower = sum(operand.min for operand in nodes)
        upper = sum(operand.max for operand in nodes)
        node.min, node.max = lower, upper
    elif typ == AndNode:
        lower = min(operand.min for operand in nodes)
        upper = max(operand.max for operand in nodes)
        node.min, node.max = lower, upper
    return create_node(node)
[docs]
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str:
    """
    Render `a` as text, whether it is a plain int or a symbolic node.

    :param a: The value to render.
    :type a: Union[Node, int]
    :param ops: Renderer table forwarded to ``a.render`` (unused for ints).
    :type ops: Any
    :param ctx: Rendering context forwarded to ``a.render`` (unused for ints).
    :type ctx: Any
    :return: ``str(a)`` for an int, otherwise ``a.render(ops, ctx)``.
    :rtype: str
    """
    if not isinstance(a, int):
        return a.render(ops, ctx)
    return str(a)
[docs]
def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int:
    """
    Evaluate `a` to a concrete number using the given variable assignments.

    Plain ints and floats are returned unchanged. For a node, every variable in
    `var_vals` is replaced by a ``NumNode`` holding its value. The substituted
    result must collapse to a single ``NumNode``, whose value is returned.

    :param a: The node or number to evaluate.
    :type a: Union[Node, int]
    :param var_vals: Concrete integer value for each variable.
    :type var_vals: Dict[Variable, int]
    :return: The evaluated value.
    :rtype: int
    :raises AssertionError: If substitution does not reduce `a` to a NumNode.
        This is checked with ``assert``, so the check is skipped under ``python -O``.
    """
    if isinstance(a, (int, float)):
        return a  # already concrete, nothing to substitute
    bindings = {var: NumNode(value) for var, value in var_vals.items()}
    evaluated = a.substitute(bindings)
    assert isinstance(evaluated, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
    return evaluated.b
# symbolic int: either a concrete Python int or a symbolic Node expression.
# These aliases are assigned after the classes they name. Earlier annotations
# that mention them still work because the module starts with
# `from __future__ import annotations`, which keeps annotations unevaluated.
sint = Union[Node, int]
# A Variable or a NumNode. Used, for example, as the key type of the
# `var_vals` mapping taken by the `substitute` methods.
VariableOrNum = Union[Variable, NumNode]
# Default renderer table. Node.render falls back to it when no `ops` is given,
# and it dispatches as ops[type(node)](node, ops, ctx). ctx == "DEBUG" and
# ctx == "REPR" select the alternate Variable / NumNode formats; any other ctx
# renders plain source-like text.


def _render_variable(self, ops, ctx):
    # DEBUG shows the bounds (plus the bound value, if any).
    # REPR rebuilds the constructor call.
    if ctx == "DEBUG":
        head = f"{self.expr}[{self.min}-{self.max}"
        bound = "=" + str(self.val) if self._val is not None else ""
        return head + bound + "]"
    if ctx == "REPR":
        ctor = f"Variable('{self.expr}', {self.min}, {self.max})"
        return ctor + (f".bind({self.val})" if self._val is not None else "")
    return f"{self.expr}"


def _render_num(self, ops, ctx):
    if ctx == "REPR":
        return f"NumNode({self.b})"
    return f"{self.b}"


def _render_mul(self, ops, ctx):
    lhs, rhs = self.a, self.b
    # Two named Variables are printed with the smaller name first;
    # every other product is printed as a*b.
    swap = (
        isinstance(lhs, Variable)
        and isinstance(rhs, Variable)
        and lhs.expr
        and rhs.expr
        and rhs.expr < lhs.expr
    )
    if swap:
        return f"({sym_render(rhs,ops,ctx)}*{lhs.render(ops,ctx)})"
    return f"({lhs.render(ops,ctx)}*{sym_render(rhs,ops,ctx)})"


def _render_div(self, ops, ctx):
    return f"({self.a.render(ops,ctx)}//{self.b})"


def _render_mod(self, ops, ctx):
    return f"({self.a.render(ops,ctx)}%{self.b})"


def _render_lt(self, ops, ctx):
    return f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})"


def _render_joined(separator):
    # Build a renderer that sorts the rendered operands and joins them with `separator`.
    def render_reduction(self, ops, ctx):
        parts = sorted(x.render(ops, ctx) for x in self.nodes)
        return "(" + separator.join(parts) + ")"

    return render_reduction


render_python: Dict[Type, Callable] = {
    Variable: _render_variable,
    NumNode: _render_num,
    MulNode: _render_mul,
    DivNode: _render_div,
    ModNode: _render_mod,
    LtNode: _render_lt,
    SumNode: _render_joined("+"),
    AndNode: _render_joined(" and "),
}