# Source code for tinygrad.shape.symbolic

from __future__ import annotations
import functools
from math import gcd
from itertools import product
from tinygrad.helpers import partition
from typing import (
    List,
    Dict,
    Callable,
    Tuple,
    Type,
    Union,
    Optional,
    Any,
    Iterator,
    Set,
)

# NOTE: Python has different behavior for negative mod and floor div than C.
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod.


def is_sym_int(x: Any) -> bool:
    """Tell whether *x* can stand in for a (possibly symbolic) integer.

    Args:
        x: Any object.

    Returns:
        True if ``x`` is a plain ``int`` or a symbolic ``Node``; False otherwise.
    """
    sym_int_types = (int, Node)
    return isinstance(x, sym_int_types)
class Node:
    """Base class of the symbolic integer expression tree.

    Every node tracks an inclusive integer range ``[min, max]`` that bounds the
    values it can take. The Python operators are overloaded so that ``+``,
    ``-``, ``*``, ``//``, ``%`` and the comparisons build (and eagerly
    simplify) new nodes rather than computing concrete values.

    Attributes:
        b: The second operand of an op node, or the value of a ``NumNode``.
        min: Inclusive lower bound of the node's value.
        max: Inclusive upper bound of the node's value.
    """

    b: Union[Node, int]
    min: int
    max: int

    def render(self, ops=None, ctx=None) -> Any:
        """Render the expression through a per-node-type dispatch table.

        Args:
            ops: Mapping from node type to a ``(node, ops, ctx) -> Any``
                renderer. Defaults to ``render_python``.
            ctx: Opaque context forwarded to the renderer (this class itself
                passes ``"DEBUG"`` for :attr:`key` and ``"REPR"`` for repr).

        Returns:
            Whatever the renderer registered for ``type(self)`` returns.
        """
        if ops is None:
            ops = render_python
        # only leaf nodes may have a single value; composite nodes with
        # min == max are expected to have been folded by create_node
        assert self.__class__ in (Variable, NumNode) or self.min != self.max
        return ops[type(self)](self, ops, ctx)

    def vars(self) -> Set[Variable]:
        """Return the Variables used by this node (none for the base class)."""
        return set()

    def expand_idx(self) -> VariableOrNum:
        """Return a Variable of this node whose ``expr`` is None, else ``NumNode(0)``.

        When several Variables qualify, set iteration order decides which one
        is returned.
        """
        return next((v for v in self.vars() if v.expr is None), NumNode(0))

    # expand a Node into List[Node] that enumerates the underlying Variables from min to max
    # expand increments earlier variables faster than later variables (as specified in the argument)
    @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
    def expand(self, idxs: Optional[Tuple[VariableOrNum, ...]] = None) -> List[Node]:
        """Enumerate this node over every assignment of the index Variables.

        Args:
            idxs: The Variables to enumerate. Defaults to
                ``(self.expand_idx(),)``.

        Returns:
            One substituted node per assignment, in :meth:`iter_idxs` order.
            Results are memoized in an unbounded cache keyed on
            ``(self, idxs)``; repeated calls return the same list object, so
            callers should not mutate it.
        """
        if idxs is None:
            idxs = (self.expand_idx(),)
        return [
            self.substitute(dict(zip(idxs, (NumNode(x) for x in rep))))
            for rep in Node.iter_idxs(idxs)
        ]

    @staticmethod
    def iter_idxs(idxs: Tuple[VariableOrNum, ...]) -> Iterator[Tuple[int, ...]]:
        """Yield every value combination of *idxs*, each over ``[min, max]``.

        The first index varies fastest.

        Args:
            idxs: Variables (or NumNodes) whose ranges are enumerated.

        Yields:
            Tuples of ints aligned with *idxs*.
        """
        # product() varies its LAST argument fastest, so feed the ranges in
        # reverse and flip each combination back: the first idx varies fastest
        yield from (
            combo[::-1]
            for combo in product(*[range(v.min, v.max + 1) for v in idxs[::-1]])
        )

    # substitute Variables with the values in var_vals
    def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
        """Replace Variables by the nodes in *var_vals*; subclasses must override.

        Raises:
            RuntimeError: always, naming the concrete class.
        """
        raise RuntimeError(self.__class__.__name__)

    def unbind(self) -> Tuple[Node, Optional[int]]:
        """Replace every *bound* Variable in this node with an unbound copy.

        Unbound Variables are left untouched.

        Returns:
            ``(node_with_unbound_variables, None)``.
        """
        # read ``_val`` directly: the ``val`` property asserts on unbound
        # Variables, which would turn this filter into a crash
        bound = {v: v.unbind()[0] for v in self.vars() if v._val is not None}
        return self.substitute(bound), None

    @functools.cached_property
    def key(self) -> str:
        """Structural identity string (the ``"DEBUG"`` rendering), computed once."""
        return self.render(ctx="DEBUG")

    @functools.cached_property
    def hash(self) -> int:
        """Hash of :attr:`key`, computed once."""
        return hash(self.key)

    def __repr__(self):
        return self.render(ctx="REPR")

    def __str__(self):
        return "<" + self.key + ">"

    def __hash__(self):
        return self.hash

    def __bool__(self):
        """False only for the constant 0 (``min == max == 0``)."""
        return not (self.max == self.min == 0)

    def __eq__(self, other: object) -> bool:
        """Structural equality: two Nodes are equal when their keys match."""
        if not isinstance(other, Node):
            return NotImplemented
        return self.key == other.key

    def __neg__(self):
        return self * -1

    def __add__(self, b: Union[Node, int]):
        """Return the simplified sum ``self + b`` (ints are wrapped in NumNode)."""
        return Variable.sum([self, b if isinstance(b, Node) else NumNode(b)])

    def __radd__(self, b: int):
        return self + b

    def __sub__(self, b: Union[Node, int]):
        return self + -b

    def __rsub__(self, b: int):
        return -self + b

    def __le__(self, b: Union[Node, int]):
        """Return a node for ``self <= b``, built as ``self < b + 1`` (a Node, not a bool)."""
        return self < (b + 1)

    def __gt__(self, b: Union[Node, int]):
        """Return a node for ``self > b``, built as ``-self < -b`` (a Node, not a bool)."""
        return (-self) < (-b)

    def __ge__(self, b: Union[Node, int]):
        """Return a node for ``self >= b``, built as ``-self < -b + 1`` (a Node, not a bool)."""
        return (-self) < (-b + 1)

    def __lt__(self, b: Union[Node, int]):
        """Return the node ``self < b``, folded to a NumNode when its bounds decide it."""
        return create_node(LtNode(self, b))

    def __mul__(self, b: Union[Node, int]):
        """Return the node ``self * b``.

        Multiplying by 0 or 1 short-circuits; a NumNode operand on either side
        is reduced to its int value; otherwise a MulNode is built.
        """
        if b == 0:
            return NumNode(0)
        if b == 1:
            return self
        if self.__class__ is NumNode:
            return NumNode(self.b * b) if isinstance(b, int) else b * self.b
        if isinstance(b, NumNode):
            return create_node(MulNode(self, b.b))
        return create_node(MulNode(self, b))

    def __rmul__(self, b: int):
        return self * b

    # *** complex ops ***

    def __rfloordiv__(self, b: int):
        """Return ``b // self`` for an int *b*.

        Returns 0 when ``self.min > b >= 0``; a constant ``self`` is divided
        exactly.

        Raises:
            RuntimeError: for any other combination.
        """
        if self.min > b >= 0:
            return NumNode(0)
        if isinstance(self, NumNode):
            return NumNode(b // self.b)
        raise RuntimeError(f"not supported: {b} // {self}")

    def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
        """Return the node ``self // b``.

        Args:
            b: An int, or a Node. Node divisors are only supported when they are
                NumNodes, equal to ``self``, or provably larger than a
                non-negative ``self``.
            factoring_allowed: Accepted for signature compatibility with
                subclass overrides (e.g. SumNode); not read here.

        Raises:
            RuntimeError: for an unsupported Node divisor.
            AssertionError: if ``b == 0``.
        """
        if isinstance(b, Node):
            if b.__class__ is NumNode:
                return self // b.b
            if self == b:
                return NumNode(1)
            if (b - self).min > 0 and self.min >= 0:
                return NumNode(0)  # b - self simplifies the node
            raise RuntimeError(f"not supported: {self} // {b}")
        assert b != 0
        if b < 0:
            # NOTE(review): this truncates rather than floors (Python gives
            # 1 // -2 == -1, this gives 0); kept as-is, confirm callers rely on it
            return (self // -b) * -1
        if b == 1:
            return self
        # the numerator of div is not allowed to be negative
        if self.min < 0:
            offset = self.min // b
            # factor out an "offset" to make the numerator positive. don't allow factoring again
            return (self + -offset * b).__floordiv__(b, factoring_allowed=False) + offset
        return create_node(DivNode(self, b))

    def __rmod__(self, b: int):
        """Return ``b % self`` for an int *b*.

        Returns ``b`` when ``self.min > b >= 0``; a constant ``self`` is reduced
        exactly.

        Raises:
            RuntimeError: for any other combination.
        """
        if self.min > b >= 0:
            return NumNode(b)
        if isinstance(self, NumNode):
            return NumNode(b % self.b)
        raise RuntimeError(f"not supported: {b} % {self}")

    def __mod__(self, b: Union[Node, int]):
        """Return the node ``self % b``.

        Args:
            b: A positive int, or a Node. Node moduli are only supported when
                they are NumNodes, equal to ``self``, or provably larger than a
                non-negative ``self``.

        Raises:
            RuntimeError: for an unsupported Node modulus.
            AssertionError: if an int ``b`` is not positive.
        """
        if isinstance(b, Node):
            if b.__class__ is NumNode:
                return self % b.b
            if self == b:
                return NumNode(0)
            if (b - self).min > 0 and self.min >= 0:
                return self  # b - self simplifies the node
            raise RuntimeError(f"not supported: {self} % {b}")
        assert b > 0
        if b == 1:
            return NumNode(0)
        if isinstance(self.max, int) and isinstance(self.min, int):
            if self.min >= 0 and self.max < b:
                return self
            # whole range lies within one period: subtract that period's base
            if (self.min // b) == (self.max // b):
                return self - (b * (self.min // b))
            # shift a negative range up by a multiple of b, then retry
            if self.min < 0:
                return (self - ((self.min // b) * b)) % b
        return create_node(ModNode(self, b))

    @staticmethod
    def sum(nodes: List[Node]) -> Node:
        """Build a simplified sum of *nodes*.

        Constant-zero terms are dropped, nested SumNodes are flattened, terms
        sharing a base are merged (``x*2 + x`` becomes ``x*3``) and all
        constants are folded into one trailing NumNode.

        Args:
            nodes: Terms to add.

        Returns:
            ``NumNode(0)`` for an empty sum, the single term when only one
            remains, otherwise a SumNode built by ``create_rednode``.
        """
        nodes = [x for x in nodes if x.max or x.min]
        if not nodes:
            return NumNode(0)
        if len(nodes) == 1:
            return nodes[0]

        # coefficient per base node; a bare (non-Mul) node counts as base * 1
        mul_groups: Dict[Node, int] = {}
        num_node_sum = 0
        for node in SumNode(nodes).flat_components:
            if node.__class__ is NumNode:
                num_node_sum += node.b
            elif node.__class__ is MulNode:
                mul_groups[node.a] = mul_groups.get(node.a, 0) + node.b
            else:
                mul_groups[node] = mul_groups.get(node, 0) + 1
        new_nodes = [
            MulNode(a, b_sum) if b_sum != 1 else a
            for a, b_sum in mul_groups.items()
            if b_sum != 0
        ]
        if num_node_sum:
            new_nodes.append(NumNode(num_node_sum))
        if len(new_nodes) > 1:
            return create_rednode(SumNode, new_nodes)
        return new_nodes[0] if len(new_nodes) == 1 else NumNode(0)

    @staticmethod
    def ands(nodes: List[Node]) -> Node:
        """Build a simplified logical AND of *nodes*.

        Returns:
            ``NumNode(1)`` for no operands; the operand itself for one
            operand; ``NumNode(0)`` if any operand is the constant 0.
            Otherwise every constant operand is dropped and the remaining
            operands are combined into an AndNode (or returned alone, or
            ``NumNode(1)`` if none remain).
        """
        if not nodes:
            return NumNode(1)
        if len(nodes) == 1:
            return nodes[0]
        if any(not x for x in nodes):
            return NumNode(0)
        # filter 1s (after the zero check, every constant operand is non-zero)
        nodes = [x for x in nodes if x.min != x.max]
        if len(nodes) > 1:
            return create_rednode(AndNode, nodes)
        return nodes[0] if len(nodes) == 1 else NumNode(1)
# 4 basic node types
class Variable(Node):
    """A symbolic integer variable ranging over the inclusive ``[nmin, nmax]``.

    A Variable may additionally be bound, in place, to a concrete value with
    :meth:`bind`. The bound value is read through :attr:`val`.

    Attributes:
        expr: Optional name of the variable.
        min: Inclusive lower bound.
        max: Inclusive upper bound.
    """

    def __new__(cls, expr: Optional[str], nmin: int, nmax: int):
        """Create a Variable, or a ``NumNode`` when the range has a single value.

        Raises:
            AssertionError: if ``nmin < 0`` or ``nmin > nmax``.
        """
        assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
        if nmin == nmax:
            # a one-value range is a constant; a NumNode is not a Variable,
            # so Python will not run Variable.__init__ on it
            return NumNode(nmin)
        return super().__new__(cls)

    def __init__(self, expr: Optional[str], nmin: int, nmax: int):
        self.expr, self.min, self.max = expr, nmin, nmax
        self._val: Optional[int] = None

    @property
    def val(self):
        """The bound value.

        Raises:
            AssertionError: if the Variable is not bound.
        """
        assert self._val is not None, f"Variable isn't bound, can't access val of {self}"
        return self._val

    def bind(self, val):
        """Bind this Variable, in place, to *val* and return ``self``.

        Raises:
            AssertionError: if already bound or *val* is outside ``[min, max]``.
        """
        assert self._val is None and self.min <= val <= self.max, f"cannot bind {val} to {self}"
        self._val = val
        return self

    def unbind(self) -> Tuple[Variable, int]:
        """Return a fresh unbound copy of this Variable together with its bound value.

        Raises:
            AssertionError: if the Variable is not bound.
        """
        # test ``_val`` directly: going through ``self.val`` would trip that
        # property's own assertion first, so this message could never appear
        assert self._val is not None, f"cannot unbind {self}"
        return Variable(self.expr, self.min, self.max), self._val

    def vars(self):
        """Return ``{self}``."""
        return {self}

    def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
        """Return ``var_vals[self]`` if present, otherwise ``self``."""
        return var_vals[self] if self in var_vals else self
class NumNode(Node):
    """A constant integer leaf: ``min == max == b``."""

    def __init__(self, num: int):
        """Wrap the int *num*.

        Raises:
            AssertionError: if *num* is not an int.
        """
        assert isinstance(num, int), f"{num} is not an int"
        self.b: int = num
        self.min = num
        self.max = num

    def bind(self, val):
        """A constant is already 'bound'; accept only its own value and return ``self``.

        Raises:
            AssertionError: if *val* differs from the constant.
        """
        assert self.b == val, f"cannot bind {val} to {self}"
        return self

    def __eq__(self, other):
        # value comparison, so NumNode(3) == 3 holds (and, via the reflected
        # call, NumNode(3) == NumNode(3) as well)
        return self.b == other

    def __hash__(self):
        # overriding __eq__ would otherwise reset __hash__ to None
        return self.hash

    def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
        """Constants contain no Variables: return ``self`` unchanged."""
        return self
def create_node(ret: Node):
    """Fold *ret* into a NumNode when its bounds pin it to a single value.

    Args:
        ret: A freshly built node whose ``min``/``max`` are already set.

    Returns:
        ``NumNode(ret.min)`` when ``ret.min == ret.max``, otherwise ``ret``.

    Raises:
        AssertionError: if ``ret.min > ret.max``.
    """
    lo, hi = ret.min, ret.max
    assert lo <= hi, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
    return NumNode(lo) if lo == hi else ret
class OpNode(Node):
    """A binary operation node ``a <op> b``.

    Attributes:
        a: Left operand.
        b: Right operand (a Node or an int).
    """

    def __init__(self, a: Node, b: Union[Node, int]):
        self.a = a
        self.b = b
        # bounds are computed eagerly, after both operands are stored
        self.min, self.max = self.get_bounds()

    def vars(self):
        """Return the union of the Variables of both operands (an int ``b`` has none)."""
        left = self.a.vars()
        right = self.b.vars() if isinstance(self.b, Node) else set()
        return left | right

    def get_bounds(self) -> Tuple[int, int]:
        """Return ``(min, max)`` for this node; subclasses must override.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("must be implemented")
class LtNode(OpNode):
    """Comparison node ``a < b``; its value range is a subset of ``{0, 1}``."""

    def __floordiv__(self, b: Union[Node, int], _=False):
        """Divide both sides of the comparison by *b*: ``(a // b) < (self.b // b)``.

        NOTE(review): this is exact when ``self.b`` is a multiple of *b*;
        otherwise it is only a sufficient condition. Presumably callers only
        divide in the exact case, but confirm.
        """
        return (self.a // b) < (self.b // b)

    def get_bounds(self) -> Tuple[int, int]:
        """Return ``(1, 1)`` if ``a < b`` always holds, ``(0, 0)`` if never, else ``(0, 1)``."""
        if isinstance(self.b, int):
            rhs_lo = rhs_hi = self.b
        else:
            rhs_lo, rhs_hi = self.b.min, self.b.max
        if self.a.max < rhs_lo:
            return (1, 1)
        if self.a.min >= rhs_hi:
            return (0, 0)
        return (0, 1)

    def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
        """Substitute into both operands and rebuild the comparison."""
        lhs = self.a.substitute(var_vals)
        rhs = self.b if isinstance(self.b, int) else self.b.substitute(var_vals)
        return lhs < rhs
class MulNode(OpNode):
    """Product node ``a * b``."""

    def __lt__(self, b: Union[Node, int]):
        """Return a node for ``a * k < b``, moving the int coefficient k to the right side when possible."""
        if isinstance(b, Node) or isinstance(self.b, Node) or self.b == -1:
            return Node.__lt__(self, b)
        # for integer a and k != 0:  a*k < b  <=>  sign(k)*a < ceil(b / |k|)
        sign = 1 if self.b > 0 else -1
        scale = abs(self.b)
        return Node.__lt__(self.a * sign, (b + scale - 1) // scale)

    def __mul__(self, b: Union[Node, int]):
        """Fold the new factor into the coefficient: ``a * (self.b * b)``."""
        return self.a * (self.b * b)  # two muls in one mul

    def __floordiv__(self, b: Union[Node, int], factoring_allowed=False):  # NOTE: mod negative isn't handled right
        """Return ``(a * k) // b``, cancelling the divisor against k when one divides the other."""
        if self.b % b == 0:
            return self.a * (self.b // b)
        if b % self.b == 0 and self.b > 0:
            return self.a // (b // self.b)
        return Node.__floordiv__(self, b, factoring_allowed)

    def __mod__(self, b: Union[Node, int]):
        """Return ``(a * k) % b``, first reducing the coefficient to ``k % b``."""
        reduced = self.a * (self.b % b)
        return Node.__mod__(reduced, b)

    def get_bounds(self) -> Tuple[int, int]:
        """Scale ``a``'s bounds by ``b``, swapping them when ``b`` is negative."""
        lo = self.a.min * self.b
        hi = self.a.max * self.b
        return (lo, hi) if self.b >= 0 else (hi, lo)

    def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
        """Substitute into both factors and rebuild the product."""
        base = self.a.substitute(var_vals)
        factor = self.b if isinstance(self.b, int) else self.b.substitute(var_vals)
        return base * factor
class DivNode(OpNode):
    """Floor-division node ``a // b`` with an int divisor ``b`` and a non-negative ``a``."""

    def __floordiv__(self, b: Union[Node, int], _=False):
        """Merge consecutive divisions: ``(a // x) // b`` becomes ``a // (x * b)``."""
        # (a // x) // y == a // (x * y) for positive integers x, y
        return self.a // (self.b * b)  # two divs is one div

    def get_bounds(self) -> Tuple[int, int]:
        """Divide ``a``'s bounds by ``b``.

        Raises:
            AssertionError: if ``a.min < 0`` or ``b`` is not an int.
        """
        assert self.a.min >= 0 and isinstance(self.b, int)
        lo = self.a.min // self.b
        hi = self.a.max // self.b
        return lo, hi

    def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
        """Substitute into the dividend and divide again."""
        dividend = self.a.substitute(var_vals)
        return dividend // self.b
class ModNode(OpNode):
    """Modulo node ``a % b`` with an int modulus ``b`` and a non-negative ``a``."""

    def __mod__(self, b: Union[Node, int]):
        """Return ``(a % m) % b``, collapsing it to ``a % b`` when b divides m."""
        if isinstance(b, Node) or isinstance(self.b, Node):
            return Node.__mod__(self, b)
        if gcd(self.b, b) == b:
            # b divides m, so (a % m) % b == a % b
            return self.a % b
        return Node.__mod__(self, b)

    def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
        """Return ``(a % m) // b``, moving the division inside the mod when b divides m."""
        if self.b % b == 0:
            # for positive b dividing m: (a % m) // b == (a // b) % (m // b)
            return (self.a // b) % (self.b // b)  # put the div inside mod
        return Node.__floordiv__(self, b, factoring_allowed)

    def get_bounds(self) -> Tuple[int, int]:
        """Return ``(0, b-1)`` if ``a``'s range may wrap around ``b``, else ``a``'s bounds mod ``b``.

        Raises:
            AssertionError: if ``a.min < 0`` or ``b`` is not an int.
        """
        assert self.a.min >= 0 and isinstance(self.b, int)
        lo, hi, m = self.a.min, self.a.max, self.b
        wraps = hi - lo >= m or (lo != hi and lo % m >= hi % m)
        if wraps:
            return (0, m - 1)
        return (lo % m, hi % m)

    def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
        """Substitute into the dividend and take the modulus again."""
        dividend = self.a.substitute(var_vals)
        return dividend % self.b
class RedNode(Node):
    """Base class for n-ary nodes (sums, ands) over a list of child ``nodes``.

    The name presumably stands for "reduce". ``min``/``max`` are not set here;
    ``create_rednode`` assigns them.
    """

    def __init__(self, nodes: List[Node]):
        self.nodes = nodes

    def vars(self) -> Set[Variable]:
        """Return the union of the children's Variables as a new set."""
        found: Set[Variable] = set()
        for child in self.nodes:
            found |= child.vars()
        return found
[docs] class SumNode(RedNode): """ SumNode class for representing sum of nodes. Attributes: __mul__ (method): Distribute multiplication over summation. __floordiv__ (method): Perform floor division with consideration to factoring. """
[docs] @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none def __mul__(self, b: Union[Node, int]): """ Multiply a SumNode with a Node or an integer. Args: self (SumNode): The SumNode object. b (Union[Node, int]): A Node or an integer to multiply the SumNode by. Returns: Node: Result of multiplication operation. """ return Node.sum([x * b for x in self.nodes]) # distribute mul into sum
[docs] @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none def __floordiv__(self, b: Union[Node, int], factoring_allowed=True): """ Perform floor division on a SumNode. The operation considers factoring and fully dividing nodes before performing the operation. Args: self (SumNode): The SumNode object. b (Union[Node, int]): A Node or an integer to divide the SumNode by. factoring_allowed (bool, optional): Whether factoring is allowed in the division. Defaults to True. Returns: Node: Result of floor division operation. """ fully_divided: List[Node] = [] rest: List[Node] = [] if isinstance(b, SumNode): nu_num = sum( node.b for node in self.flat_components if node.__class__ is NumNode ) de_num = sum( node.b for node in b.flat_components if node.__class__ is NumNode ) if nu_num > 0 and de_num and (d := nu_num // de_num) > 0: return NumNode(d) + (self - b * d) // b if isinstance(b, Node): for x in self.flat_components: if x % b == 0: fully_divided.append(x // b) else: rest.append(x) if (sum_fully_divided := create_rednode(SumNode, fully_divided)) != 0: return sum_fully_divided + create_rednode(SumNode, rest) // b return Node.__floordiv__(self, b, False) if b == 1: return self if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed) fully_divided, rest = [], [] _gcd = b divisor = 1 for x in self.flat_components: if x.__class__ in (NumNode, MulNode): if x.b % b == 0: fully_divided.append(x // b) else: rest.append(x) if isinstance(x.b, int): _gcd = gcd(_gcd, x.b) if x.__class__ == MulNode and divisor == 1 and b % x.b == 0: divisor = x.b else: _gcd = 1 else: rest.append(x) _gcd = 1 if _gcd > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(_gcd) // ( b // _gcd ) if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // ( b // divisor ) return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b)
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none def __mod__(self, b: Union[Node, int]): """ Perform modulo operation with the given node or integer. This method first checks if `b` is an instance of `SumNode`. If so, it calculates the sum of `node.b` for all `NumNode` instances in `self.flat_components` and `b.flat_components`, and performs a modulo operation with the calculated values. Next, if `b` is an instance of `Node` and `(b - self).min > 0`, it returns `self`. Otherwise, it creates a new list of nodes by iterating over the nodes in `self.nodes`. For each node: - If it's a `NumNode`, it appends a new `NumNode` with the modulo operation result to `new_nodes`. - If it's a `MulNode`, it appends the product of `x.a` and `(x.b % b)` to `new_nodes`. - Otherwise, it appends the node itself to `new_nodes`. Finally, it returns the result of the modulo operation between the sum of all nodes in `new_nodes` and `b`. :param b: The node or integer to perform the modulo operation with. :type b: Union[Node, int] :return: The result of the modulo operation. """ if isinstance(b, SumNode): nu_num = sum( node.b for node in self.flat_components if node.__class__ is NumNode ) de_num = sum( node.b for node in b.flat_components if node.__class__ is NumNode ) if nu_num > 0 and de_num and (d := nu_num // de_num) > 0: return (self - b * d) % b if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node new_nodes: List[Node] = [] for x in self.nodes: if x.__class__ is NumNode: new_nodes.append(NumNode(x.b % b)) elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b % b)) else: new_nodes.append(x) return Node.__mod__(Node.sum(new_nodes), b) def __lt__(self, b: Union[Node, int]): """ Perform a less-than comparison between this node and the given node or integer. If `b` is an integer, it first calculates the difference between `b` and the sum of all `NumNode` instances in `self.nodes`. 
Then, it creates two lists: `muls`, containing nodes that are instances of `MulNode` with positive values, and `others`, containing other nodes. If there are any elements in `muls` and the sum of all `others` is greater than or equal to 0 but less than a calculated greatest common divisor (GCD), it returns True. Otherwise, it performs a less-than comparison between this node (`lhs`) and `b`. :param b: The node or integer to compare with. :type b: Union[Node, int] :return: The result of the less-than comparison. """ lhs: Node = self if isinstance(b, int): new_sum = [] for x in self.nodes: # TODO: should we just force the last one to always be the number if isinstance(x, NumNode): b -= x.b else: new_sum.append(x) lhs = Node.sum(new_sum) nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs] assert all( not isinstance(node, MulNode) or isinstance(node.b, int) for node in nodes ), "not supported" muls, others = partition( nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b ) if muls: # NOTE: gcd in python 3.8 takes exactly 2 args mul_gcd = b for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above all_others = Variable.sum(others) if all_others.min >= 0 and all_others.max < mul_gcd: lhs, b = ( Variable.sum([mul // mul_gcd for mul in muls]), b // mul_gcd, ) return Node.__lt__(lhs, b)
[docs] def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: """ Substitute the variables in the expression with their corresponding values. :param var_vals: A dictionary containing variable-value pairs. :type var_vals: Dict[VariableOrNum, Node] :return: The resulting node after substitution. :rtype: Node """ return Variable.sum([node.substitute(var_vals) for node in self.nodes])
@property
def flat_components(self):
    """
    Return this sum's terms with nested ``SumNode`` terms expanded recursively.

    Terms keep their original left-to-right order. A term that is itself a
    ``SumNode`` is replaced in place by its own flattened components. A fresh
    list is built on every access.

    :return: The flattened components of the SumNode.
    :rtype: List[Node]
    """
    flattened = []
    for child in self.nodes:
        if isinstance(child, SumNode):
            # recurse so arbitrarily deep nesting collapses into one flat list
            flattened.extend(child.flat_components)
        else:
            flattened.append(child)
    return flattened
class AndNode(RedNode):
    """
    Reduction node whose children are combined with ``Variable.ands``.

    Attributes:
        RedNode (Class): Parent class for this class.
    """

    def __floordiv__(self, b: Union[Node, int], _=True):
        """
        Floor-divide every child by ``b`` and recombine the quotients.

        Parameters:
            self (AndNode): The instance of AndNode.
            b (Union[Node, int]): The divisor applied to each child node.
            _ (bool): Accepted for call-signature compatibility and ignored.

        Returns:
            Node: ``Variable.ands`` applied to ``[child // b for each child]``.
        """
        quotients = [child // b for child in self.nodes]
        return Variable.ands(quotients)

    def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
        """
        Substitute variables in every child and recombine with ``Variable.ands``.

        Children are processed in order. As soon as one child substitutes to a
        falsy node, the whole result short-circuits to ``NumNode(0)``, and the
        remaining children are not substituted.

        Parameters:
            self (AndNode): The instance of AndNode.
            var_vals (Dict[VariableOrNum, Node]): Variable-to-node replacements.

        Returns:
            Node: ``NumNode(0)`` on short-circuit, otherwise ``Variable.ands``
            over the substituted children.
        """
        substituted: List[Node] = []
        for child in self.nodes:
            replacement = child.substitute(var_vals)
            # falsiness is decided by the node's own __bool__ (defined elsewhere)
            if not replacement:
                return NumNode(0)
            substituted.append(replacement)
        return Variable.ands(substituted)
def create_rednode(typ: Type[RedNode], nodes: List[Node]):
    """
    Instantiate a reduction node, fill in its bounds, and finalize it.

    For ``SumNode`` the bounds are the sums of the children's bounds. For
    ``AndNode`` the lower bound is the smallest child ``min`` and the upper
    bound is the largest child ``max``. Any other ``typ`` is left with whatever
    bounds its constructor set. The result is passed through ``create_node``.

    :param typ: The reduction node class to instantiate.
    :type typ: Type[RedNode]
    :param nodes: The child nodes of the new reduction node.
    :type nodes: List[Node]
    :return: Whatever ``create_node`` returns for the new node.
    """
    red = typ(nodes)
    if typ is SumNode:
        red.min = sum(child.min for child in nodes)
        red.max = sum(child.max for child in nodes)
    elif typ is AndNode:
        red.min = min(child.min for child in nodes)
        red.max = max(child.max for child in nodes)
    return create_node(red)
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str:
    """
    Render a symbolic value that may be a plain int or a node.

    Ints (including bool, an int subclass) are converted with ``str``. Anything
    else is rendered through its own ``render(ops, ctx)`` method.

    :param a: The node or integer to be rendered.
    :type a: Union[Node, int]
    :param ops: Renderer table forwarded to ``Node.render``. Default is None.
    :type ops: Any
    :param ctx: Rendering context forwarded to ``Node.render``. Default is None.
    :type ctx: Any
    :return: The rendered form of ``a``.
    :rtype: str
    """
    if not isinstance(a, int):
        return a.render(ops, ctx)
    return str(a)
def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int:
    """
    Evaluate a symbolic value to a concrete number.

    Plain ints and floats are returned unchanged. A node has each variable in
    ``var_vals`` replaced by a ``NumNode`` of its value. The substitution must
    fold to a ``NumNode``, whose value is returned. Otherwise an
    ``AssertionError`` is raised.

    :param a: The node or number to evaluate.
    :type a: Union[Node, int]
    :param var_vals: Concrete integer value for each variable.
    :type var_vals: Dict[Variable, int]
    :return: The concrete value of ``a``.
    :rtype: int
    """
    if isinstance(a, (int, float)):
        return a
    bindings = {var: NumNode(value) for var, value in var_vals.items()}
    folded = a.substitute(bindings)
    assert isinstance(folded, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
    return folded.b
# symbolic int
sint = Union[Node, int]
VariableOrNum = Union[Variable, NumNode]


def _render_variable(self, ops, ctx) -> str:
    """Render a Variable for the python renderer.

    ``ctx == "DEBUG"`` gives ``expr[min-max]`` plus ``=val`` when bound.
    ``ctx == "REPR"`` gives a ``Variable(...)`` constructor call plus
    ``.bind(val)`` when bound. Any other ctx gives just the expression name.
    """
    if ctx == "DEBUG":
        return f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]"
    if ctx == "REPR":
        return f"Variable('{self.expr}', {self.min}, {self.max})" + (f".bind({self.val})" if self._val is not None else "")
    return f"{self.expr}"


def _render_mul(self, ops, ctx) -> str:
    """Render a MulNode as ``(lhs*rhs)``.

    When both operands are named Variables, the operand with the smaller
    ``expr`` is printed first, so ``a*b`` and ``b*a`` render identically.
    """
    swap = (
        isinstance(self.a, Variable)
        and isinstance(self.b, Variable)
        and self.a.expr
        and self.b.expr
        and self.b.expr < self.a.expr
    )
    if swap:
        return f"({sym_render(self.b,ops,ctx)}*{self.a.render(ops,ctx)})"
    return f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})"


# Default renderer table used by Node.render when no ``ops`` are given.
# Each entry has the signature (node, ops, ctx) -> rendered value.
render_python: Dict[Type, Callable] = {
    Variable: _render_variable,
    NumNode: lambda self, ops, ctx: f"NumNode({self.b})" if ctx == "REPR" else f"{self.b}",
    MulNode: _render_mul,
    DivNode: lambda self, ops, ctx: f"({self.a.render(ops,ctx)}//{self.b})",
    ModNode: lambda self, ops, ctx: f"({self.a.render(ops,ctx)}%{self.b})",
    LtNode: lambda self, ops, ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
    # children are rendered and sorted so term order does not affect the output
    SumNode: lambda self, ops, ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
    AndNode: lambda self, ops, ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
}