Mathy uses machine learning (ML) to choose which actions to apply to which nodes in an expression tree.
It selects and applies actions in a loop to accomplish a desired task.
Mathy uses machine learning in a few ways and has the following features:
- Math Embeddings layer for processing sequences of sequences
- Policy/Value model for estimating which actions to apply to which nodes
- A3C agent for online training of agents in a CPU-friendly setting
- MCTS agent for batch training in a GPU-friendly setting
Mathy processes an input problem by parsing its text into a tree, converting that tree into a sequence of features (one per node), combining those features with the current environment state, and embedding the result into a variable-length sequence of fixed-dimension embeddings.
Text to Tree
A problem text is encoded into tokens, then parsed into a tree that preserves the order of operations while removing parentheses and whitespace. Consider the tokens and tree that result from the input:
-3 * (4 + 7)
Observe that the tree representation is more concise than the tokens array because it doesn't have nodes for hierarchical features like parentheses.
Converting text to trees is accomplished with the expression parser:
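```python
from typing import List

from mathy import ExpressionParser, MathExpression, Token, VariableExpression

problem = "4 + 2x"
parser = ExpressionParser()
# Split the problem text into a flat list of tokens
tokens: List[Token] = parser.tokenize(problem)
# Parse the problem text into an expression tree
expression: MathExpression = parser.parse(problem)
# The tree contains exactly one variable node (x)
assert len(expression.find_type(VariableExpression)) == 1
```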
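To illustrate why the tree ends up with fewer nodes than there are tokens, here is a minimal, self-contained sketch of a tokenizer and recursive-descent parser. It is not Mathy's implementation: the `Node` class, the `tokenize` regex, and the `Parser` grammar are hypothetical, and the sketch only handles `+`, `-`, `*`, `/`, parentheses, and a unary minus directly in front of a number (no implicit multiplication such as `2x`).

```python
import re
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    """Hypothetical binary tree node (not Mathy's MathExpression)."""
    value: str
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def tokenize(text: str) -> List[str]:
    # Numbers, single-letter variables, operators, and parentheses; spaces are skipped
    return re.findall(r"\d+(?:\.\d+)?|[a-z]|[-+*/()]", text)


class Parser:
    """Recursive descent: expr -> term ((+|-) term)*, term -> factor ((*|/) factor)*."""

    def __init__(self, tokens: List[str]) -> None:
        self.tokens = tokens
        self.pos = 0

    def peek(self) -> Optional[str]:
        return self.tokens[self.pos] if self.pos < len(self.tokens) else None

    def take(self) -> str:
        token = self.tokens[self.pos]
        self.pos += 1
        return token

    def parse_expr(self) -> Node:
        node = self.parse_term()
        while self.peek() in ("+", "-"):
            op = self.take()
            node = Node(op, node, self.parse_term())
        return node

    def parse_term(self) -> Node:
        node = self.parse_factor()
        while self.peek() in ("*", "/"):
            op = self.take()
            node = Node(op, node, self.parse_factor())
        return node

    def parse_factor(self) -> Node:
        token = self.take()
        if token == "(":
            node = self.parse_expr()
            self.take()  # consume ")": parentheses never become tree nodes
            return node
        if token == "-":
            # Unary minus, assumed to be followed by a number: fold into one constant
            return Node("-" + self.take())
        return Node(token)


def count_nodes(node: Optional[Node]) -> int:
    if node is None:
        return 0
    return 1 + count_nodes(node.left) + count_nodes(node.right)


tokens = tokenize("-3 * (4 + 7)")
tree = Parser(tokens).parse_expr()
print(len(tokens), count_nodes(tree))  # 8 5
```

For `-3 * (4 + 7)` this token list has 8 entries, including both parentheses, while the tree has only 5 nodes. The parentheses are consumed in `parse_factor` and survive only as the nesting of the `+` node beneath the `*` node.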
Tree to List
You might have noticed that the features from the previous tree are not listed in the natural order in which we would read them. As observed by Lample and Charton, trees must be visited in an order that preserves the order of operations so that the model can pick up on the hierarchical features of the input.
For this reason, we visit trees in pre-order for serialization.
Converting math expression trees to lists is done with a helper:
```python
from typing import List

from mathy import ExpressionParser, MathExpression

parser = ExpressionParser()
expression: MathExpression = parser.parse("4 + 2x")
# Flatten the tree into a list of its nodes
nodes: List[MathExpression] = expression.to_list()
# Five nodes in total: 4, +, 2, * (the implicit multiplication), and x
assert len(nodes) == 5
```
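The difference between visit orders can be seen with a small stand-alone sketch. The `Node` class and the traversal functions below are hypothetical and are not Mathy's `to_list` implementation; they only show what pre-order and in-order visits produce for the tree of `4 + 2x`.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    """Hypothetical binary tree node used only for this illustration."""
    value: str
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def pre_order(node: Optional[Node]) -> List[str]:
    # Visit the node itself first, then its left and right subtrees
    if node is None:
        return []
    return [node.value] + pre_order(node.left) + pre_order(node.right)


def in_order(node: Optional[Node]) -> List[str]:
    # Visit the left subtree, then the node, then the right subtree
    if node is None:
        return []
    return in_order(node.left) + [node.value] + in_order(node.right)


# "4 + 2x" as a tree: the implicit multiplication in 2x becomes a "*" node
tree = Node("+", Node("4"), Node("*", Node("2"), Node("x")))
print(pre_order(tree))  # ['+', '4', '*', '2', 'x']
print(in_order(tree))  # ['4', '+', '2', '*', 'x']

# A different tree, (4 + 2) * x, has the same in-order sequence but a different pre-order one
grouped = Node("*", Node("+", Node("4"), Node("2")), Node("x"))
assert in_order(grouped) == in_order(tree)
assert pre_order(grouped) != pre_order(tree)
```

Because pre-order places each operator before its operands, as in prefix notation, a sequence of binary operators and leaves can be decoded back into exactly one tree without parentheses. The in-order sequence loses that grouping, which is what the two final assertions check.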
Lists to Observations
Mathy turns a list of math expression nodes into a feature list that captures characteristics of the input. Specifically, Mathy converts a node list into two lists, one with node types and another with node values. When these lists are shown as a table with one column per node, alongside the input text:
- The first row is the input token characters stripped of whitespace and parentheses.
- The second row is the sequence of floating point node values for the tree, with each non-constant node represented by a mask value.
- The third row is the node type integer representing the class of the node in the tree.
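The conversion from a node list to value and type lists can be sketched as follows. The type ids, the mask value of `0.0`, and the `(kind, text)` node representation are all hypothetical choices for illustration; Mathy's actual node type integers and mask value may differ.

```python
from typing import List, Tuple

# Hypothetical type ids and mask value, chosen for illustration only
TYPE_IDS = {"constant": 1, "variable": 2, "add": 3, "multiply": 4}
MASK_VALUE = 0.0  # a real implementation might prefer a value that cannot collide with a constant


def to_features(nodes: List[Tuple[str, str]]) -> Tuple[List[float], List[int]]:
    """Split (kind, text) nodes into parallel value and type lists."""
    values: List[float] = []
    types: List[int] = []
    for kind, text in nodes:
        # Only constants carry a numeric value; every other node gets the mask value
        values.append(float(text) if kind == "constant" else MASK_VALUE)
        types.append(TYPE_IDS[kind])
    return values, types


# Pre-order node list for "-3 * (4 + 7)"
nodes = [("multiply", "*"), ("constant", "-3"), ("add", "+"), ("constant", "4"), ("constant", "7")]
values, types = to_features(nodes)
print(values)  # [0.0, -3.0, 0.0, 4.0, 7.0]
print(types)  # [4, 1, 3, 1, 1]
```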
While feature lists can be passed directly to an ML model, they don't include any information about the state of the problem over time. To capture this, Mathy agents draw extra information from the environment when building observations. This extra information includes:
- Environment Problem Type: each environment specifies a namespace string, which is hashed into a pair of values using two different random seeds.
- Episode Relative Time: each observation includes a 0-1 floating point value that indicates how close the agent is to running out of moves.
- Current and Historical RNN states: observations include the recurrent neural network (RNN) state of the agent, and a historical average state from all the timesteps in the current episode.
- Valid Action Mask: Mathy gives weighted estimates for each action at every node. If there are 5 possible actions and 10 nodes in the tree, there are up to 50 possible actions to choose from. A mask of 0/1 values with the same size (e.g. 50) is provided so that the model can mask out invalid node/action pairs when returning probability distributions.
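A minimal sketch of how these extra inputs could be computed is shown below, using only the Python standard library. Everything here is an assumption for illustration rather than Mathy's code: the SHA-256 based `hash_namespace` helper and its seeds, the example namespace string, the `relative_time` formula, the element-wise `history_average`, and the row-major `node * num_actions + action` layout of the flattened mask.

```python
import hashlib
import math
from typing import List, Tuple


def hash_namespace(namespace: str, seeds: Tuple[int, int] = (1337, 42)) -> Tuple[float, float]:
    """Hypothetical: hash a namespace once per seed, scaling each hash into [0, 1)."""
    values = []
    for seed in seeds:
        digest = hashlib.sha256(f"{seed}:{namespace}".encode("utf-8")).digest()
        # 48 bits fit exactly in a float, so the result is always < 1.0
        values.append(int.from_bytes(digest[:6], "big") / 2**48)
    return values[0], values[1]


def relative_time(moves_taken: int, max_moves: int) -> float:
    """0.0 at the start of an episode, 1.0 once the move budget is used up."""
    return min(moves_taken / max_moves, 1.0)


def history_average(states: List[List[float]]) -> List[float]:
    """Element-wise mean of the RNN states seen so far in the episode."""
    return [sum(column) / len(states) for column in zip(*states)]


def masked_softmax(logits: List[float], mask: List[int]) -> List[float]:
    """Softmax over valid entries only (assumes at least one); masked entries get 0."""
    peak = max(x for x, m in zip(logits, mask) if m)  # subtract the max for stability
    exps = [math.exp(x - peak) if m else 0.0 for x, m in zip(logits, mask)]
    total = sum(exps)
    return [e / total for e in exps]


problem_type = hash_namespace("example.polynomials.simplify")  # hypothetical namespace
time_feature = relative_time(moves_taken=5, max_moves=20)  # 0.25
average = history_average([[1.0, 2.0], [3.0, 4.0]])  # [2.0, 3.0]

num_nodes, num_actions = 10, 5
logits = [0.0] * (num_nodes * num_actions)  # one score per (node, action) pair
mask = [0] * (num_nodes * num_actions)
mask[0 * num_actions + 1] = 1  # action 1 is valid on node 0
mask[2 * num_actions + 3] = 1  # action 3 is valid on node 2
probs = masked_softmax(logits, mask)
print(probs[1], probs[13])  # 0.5 0.5
```

In this sketch the softmax never assigns probability to a masked pair, so an agent sampling from `probs` can only choose one of the valid node/action combinations.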
Mathy has utilities for making the conversion:
```python
from mathy import MathyEnv, MathyEnvState, MathyObservation, envs

env: MathyEnv = envs.PolySimplify()
state: MathyEnvState = env.get_initial_state()
observation: MathyObservation = env.state_to_observation(state, rnn_size=128)
# As many nodes as values
assert len(observation.nodes) == len(observation.values)
# Mask is number of nodes times number of actions
assert len(observation.mask) == len(observation.nodes) * env.action_size
# RNN states are the same size
assert len(observation.rnn_state_h) == 128
assert len(observation.rnn_state_c) == 128
assert len(observation.rnn_history_h) == 128
```