


Mathy uses a swarm planning algorithm to choose which actions to apply to which nodes in an expression tree.

It picks and applies actions in a loop until it accomplishes the desired task.

Specifically, Mathy uses swarm planning from the Fragile library to choose actions in both built-in and user-defined reinforcement learning environments.
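The toy sketch below illustrates the general idea behind swarm planning: a population of walkers explores actions at random, and walkers that fall behind are replaced by clones of better-scoring companions. It does not use Fragile or Mathy. The problem (reach 37 from 1 using increment, decrement, and double), the function names, and the simple cloning rule are all hypothetical. Real swarm planners such as Fragile's also reward walker diversity, which this sketch omits.

```python
import random
from typing import Callable, List, Tuple

# Toy problem (hypothetical, not Fragile's or Mathy's API):
# reach TARGET from START using three actions.
START, TARGET = 1, 37
ACTIONS: List[Callable[[int], int]] = [
    lambda x: x + 1,  # action 0: increment
    lambda x: x - 1,  # action 1: decrement
    lambda x: x * 2,  # action 2: double
]
Walker = Tuple[int, List[int]]  # (current state, action indices taken so far)


def score(state: int) -> int:
    """Higher is better; 0 means the target has been reached."""
    return -abs(TARGET - state)


def replay(plan: List[int]) -> int:
    """Apply a sequence of action indices, starting from START."""
    state = START
    for action in plan:
        state = ACTIONS[action](state)
    return state


def swarm_plan(num_walkers: int = 16, steps: int = 12, seed: int = 0) -> Walker:
    rng = random.Random(seed)
    walkers: List[Walker] = [(START, []) for _ in range(num_walkers)]
    best: Walker = (START, [])
    for _ in range(steps):
        # 1. Every walker takes one random action and records it in its plan.
        moved: List[Walker] = []
        for state, plan in walkers:
            action = rng.randrange(len(ACTIONS))
            moved.append((ACTIONS[action](state), plan + [action]))
        # Keep the best walker seen at any step, so the result never regresses.
        for walker in moved:
            if score(walker[0]) > score(best[0]):
                best = walker
        # 2. Cloning: each walker compares itself with a random companion and
        #    copies the companion's state and plan if the companion scores higher.
        walkers = []
        for state, plan in moved:
            other = moved[rng.randrange(num_walkers)]
            walkers.append(other if score(other[0]) > score(state) else (state, plan))
    return best


best_state, best_plan = swarm_plan()
print(best_state, best_plan)
```

`swarm_plan()` returns the best state found together with the action indices that reach it, and `replay` reproduces that state from the plan. Because the best walker is tracked separately from the cloning step, the result never scores worse than the starting point.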

Text Preprocessing

Mathy processes an input problem by parsing its text into a tree and converting that tree into a sequence of per-node features. Those features are then concatenated with the current environment state (problem type and episode time) and the valid action mask.

Text to Tree

A problem text is encoded into tokens, then parsed into a tree that preserves the order of operations while removing parentheses and whitespace. Consider the tokens and tree that result from the input: -3 * (4 + 7)

Token:  -   3   *    (     4   +   7   )     (end)
Type:   8   1   16   256   1   4   1   512   8192

Tree:   * at the root, with children -3 and +; the + node has children 4 and 7

Note that the tree is more concise than the token list: it has no nodes for structural tokens such as parentheses, and the minus sign and the 3 become the single constant -3.
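To show where the numbers in the token listing come from, here is a minimal tokenizer sketch that emits (value, type) pairs. It is hypothetical and much simpler than the mathy_core tokenizer: it only handles integers, `+`, `-`, `*` and parentheses, and its type codes are copied from the listing for -3 * (4 + 7). Treating the final, textless 8192 entry as an end-of-input marker is an assumption.

```python
from typing import List, Tuple

# Type codes as they appear in the token listing for "-3 * (4 + 7)".
# The names are our own labels, not mathy_core identifiers.
CONSTANT, PLUS, MINUS, MULTIPLY = 1, 4, 8, 16
OPEN_PAREN, CLOSE_PAREN, EOF = 256, 512, 8192

SYMBOLS = {"+": PLUS, "-": MINUS, "*": MULTIPLY, "(": OPEN_PAREN, ")": CLOSE_PAREN}


def toy_tokenize(text: str) -> List[Tuple[str, int]]:
    """Split text into (value, type) pairs, skipping whitespace."""
    tokens: List[Tuple[str, int]] = []
    i = 0
    while i < len(text):
        char = text[i]
        if char.isspace():
            i += 1
        elif char.isdigit():
            # Consume a run of digits as one constant token.
            start = i
            while i < len(text) and text[i].isdigit():
                i += 1
            tokens.append((text[start:i], CONSTANT))
        elif char in SYMBOLS:
            tokens.append((char, SYMBOLS[char]))
            i += 1
        else:
            raise ValueError(f"unexpected character {char!r}")
    tokens.append(("", EOF))  # assumed end-of-input marker
    return tokens


print(toy_tokenize("-3 * (4 + 7)"))
```

The leading minus stays a separate token here; it is the parser that folds it into the single constant -3 seen in the tree.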

Converting text to trees is accomplished with the expression parser:


from typing import List

from mathy_core import ExpressionParser, MathExpression, Token, VariableExpression

problem = "4 + 2x"
parser = ExpressionParser()
tokens: List[Token] = parser.tokenize(problem)
expression: MathExpression = parser.parse(problem)
assert len(expression.find_type(VariableExpression)) == 1

Tree to List

Rather than expose tree structures to environments, we traverse them to produce node/value lists.

Tree List Ordering

You might have noticed that the tree features shown earlier are not listed in the order we would naturally read them. As observed by Lample and Charton, trees must be visited in an order that preserves the order of operations so that the model can pick up on the hierarchical features of the input.

For this reason, we visit trees in pre-order for serialization.
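A minimal pre-order traversal shows the visit order for -3 * (4 + 7): each node is emitted before its left subtree, which comes before its right subtree. The `Node` class and `preorder` function are hypothetical stand-ins, not mathy_core's expression types.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    """Minimal binary expression node (hypothetical, not mathy_core's class)."""
    label: str
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def preorder(node: Optional[Node]) -> List[str]:
    """Visit the node, then its left subtree, then its right subtree."""
    if node is None:
        return []
    return [node.label] + preorder(node.left) + preorder(node.right)


# -3 * (4 + 7): the multiplication is the root because it is evaluated last.
tree = Node("*", Node("-3"), Node("+", Node("4"), Node("7")))
print(preorder(tree))  # ['*', '-3', '+', '4', '7']
```

Because every operator here takes exactly two operands, a pre-order list (Polish notation) can be decoded back into the same tree without parentheses, whereas the in-order reading -3 * 4 + 7 loses the grouping.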

Converting math expression trees to lists is done with a helper:


from typing import List

from mathy_core import ExpressionParser, MathExpression

parser = ExpressionParser()
expression: MathExpression = parser.parse("4 + 2x")
nodes: List[MathExpression] = expression.to_list()
# five nodes: the + and * operators, the constants 4 and 2, and the variable x
assert len(nodes) == 5

Lists to Observations

Mathy turns a list of math expression nodes into a feature list that captures the input's characteristics. Specifically, Mathy converts a node list into two lists, one with node types and another with node values:

Node:   *                    -3    +     4     7
Value:  0.3                  0.0   0.3   0.7   1.0
Type:   0.2857142857142857   1.0   0.0   1.0   1.0

  • The first row lists the tree's nodes in visit order, without whitespace or parentheses.
  • The second row is the sequence of floating-point node values for the tree. Constants are scaled into the 0-1 range, and each non-constant node is represented by a mask value.
  • The third row is the node type, an integer identifying the node's class in the tree (shown here scaled into the 0-1 range).
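The constant values in the second row are consistent with min-max scaling over the tree's constants: -3 is the smallest and maps to 0.0, 7 is the largest and maps to 1.0, and 4 maps to (4 - (-3)) / (7 - (-3)) = 0.7. The sketch below reproduces that row under this assumption. The 0.3 fill for non-constant nodes is simply copied from the example, and `node_values` is a hypothetical helper, not Mathy's implementation.

```python
from typing import List, Optional

MASK_VALUE = 0.3  # fill for non-constant nodes, copied from the example (assumption)


def node_values(constants: List[Optional[float]]) -> List[float]:
    """Min-max scale constants to 0-1; non-constant nodes (None) get MASK_VALUE.

    Assumes at least one node is a constant.
    """
    present = [c for c in constants if c is not None]
    low, high = min(present), max(present)
    span = (high - low) or 1.0  # avoid dividing by zero when all constants match
    return [MASK_VALUE if c is None else (c - low) / span for c in constants]


# Pre-order nodes of -3 * (4 + 7): *, -3, +, 4, 7 (operators carry no value).
print(node_values([None, -3.0, None, 4.0, 7.0]))  # [0.3, 0.0, 0.3, 0.7, 1.0]
```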

Feature lists could be passed directly to an ML model, but they carry no information about the problem's state over time. To capture that, Mathy agents draw extra information from the environment when building observations. This extra information includes:

  • Environment Problem Type: every environment specifies a namespace string, which is hashed with two different random seeds into a pair of values.
  • Episode Relative Time: each observation includes a 0-1 floating-point value that indicates how close the agent is to running out of moves.
  • Valid Action Mask: Mathy produces a weighted estimate for each action at every node. With five possible actions and ten nodes in the tree, there are up to 50 node/action pairs. A mask of the same size (here, 50) with 0/1 entries is provided so that the model can mask out invalid node/action pairs when returning probability distributions.

Mathy has utilities for making the conversion:


from mathy_envs import MathyEnv, MathyEnvState, MathyObservation, envs

env: MathyEnv = envs.PolySimplify()
state: MathyEnvState = env.get_initial_state()[0]
observation: MathyObservation = env.state_to_observation(state)

# As many nodes as values
assert len(observation.nodes) == len(observation.values)
# Mask is a binary validity mask of size (num_rules, num_nodes)
assert len(observation.mask) == len(env.rules)
assert len(observation.mask[0]) == len(observation.nodes)
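The extra observation fields described in the list above can also be sketched in plain Python. Everything here is hypothetical: the SHA-256 hash and seed values, the definition of relative time as the fraction of the move budget already used, and the helper names are illustrations, not Mathy's implementation. The masked softmax shows how a (num_rules, num_nodes) mask, flattened to one entry per rule/node pair, can give invalid pairs zero probability.

```python
import hashlib
import math
from typing import List, Tuple


def problem_type_pair(namespace: str, seeds: Tuple[int, int] = (11, 42)) -> Tuple[float, float]:
    """Hash a namespace string with two different seeds into two floats in [0, 1)."""
    values: List[float] = []
    for seed in seeds:
        digest = hashlib.sha256(f"{seed}:{namespace}".encode("utf-8")).digest()
        # 48 bits fit exactly in a float, so the result stays strictly below 1.0.
        values.append(int.from_bytes(digest[:6], "big") / 2**48)
    return values[0], values[1]


def relative_time(moves_taken: int, max_moves: int) -> float:
    """0.0 at the start of an episode, 1.0 once the move budget is used up."""
    return min(moves_taken / max_moves, 1.0)


def masked_softmax(logits: List[List[float]], mask: List[List[int]]) -> List[float]:
    """Flatten (num_rules, num_nodes) logits and give invalid pairs zero probability.

    Assumes at least one pair is valid.
    """
    flat = [x for row in logits for x in row]
    flat_mask = [m for row in mask for m in row]
    top = max(x for x, m in zip(flat, flat_mask) if m)  # subtract max for stability
    exps = [math.exp(x - top) if m else 0.0 for x, m in zip(flat, flat_mask)]
    total = sum(exps)
    return [e / total for e in exps]


print(problem_type_pair("hypothetical.namespace"))
print(relative_time(3, 10))
mask = [[1, 0, 1], [0, 0, 1]]  # 2 rules x 3 nodes; only three pairs are valid
logits = [[0.5, 9.0, 0.5], [3.0, 3.0, 0.5]]
print(masked_softmax(logits, mask))
```

In the example the invalid pairs carry the largest logits, yet they receive zero probability and the three valid pairs split the probability equally.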

Last update: November 22, 2020