Overview

Mathy's embeddings model takes in a window of observations and outputs a sequence of the same length, with a fixed-size learned embedding for each token in the sequence.

Model

The Mathy embeddings model is a stateful model that predicts over sequences. This complicates the process of collecting observations to feed to the model, but allows richer input features than would be available from the simpler state representation.

The model accepts an encoded sequence of tokens and values extracted from the current state's expression tree, along with RNN state variables to use with the recurrent processing layers.

Examples

Observations to Embeddings

You can instantiate a model and produce untrained embeddings:

from mathy import envs
from mathy.agents.base_config import BaseConfig
from mathy.agents.embedding import MathyEmbedding
from mathy.env import MathyEnv
from mathy.state import MathyObservation, observations_to_window


args = BaseConfig()
env: MathyEnv = envs.PolySimplify()
observation: MathyObservation = env.state_to_observation(
    env.get_initial_state()[0], rnn_size=args.lstm_units
)
model = MathyEmbedding(args)
# output shape is: [num_observations, max_nodes_len, embedding_dimensions]
inputs = observations_to_window([observation, observation]).to_inputs()
embeddings = model(inputs)

# We provided a window of two observations
assert embeddings.shape[0] == 2
# There is one output per node in each observation's expression tree
assert embeddings.shape[1] == len(observation.nodes)
# Each output vector has the configured embedding size
assert embeddings.shape[-1] == args.embedding_units

Access RNN states

The embeddings model is stateful, so you can inspect the recurrent network's current hidden and cell states.

import numpy as np

from mathy import envs
from mathy.agents.base_config import BaseConfig
from mathy.agents.embedding import MathyEmbedding
from mathy.env import MathyEnv
from mathy.state import MathyObservation, observations_to_window

args = BaseConfig()
env: MathyEnv = envs.PolySimplify()
observation: MathyObservation = env.state_to_observation(
    env.get_initial_state()[0], rnn_size=args.lstm_units
)
model = MathyEmbedding(args)
inputs = observations_to_window([observation]).to_inputs()

# Expect that the RNN states are zero to begin
assert np.count_nonzero(model.state_h.numpy()) == 0
assert np.count_nonzero(model.state_c.numpy()) == 0

embeddings = model(inputs)

# Expect that the RNN states are non-zero
assert np.count_nonzero(model.state_h.numpy()) > 0
assert np.count_nonzero(model.state_c.numpy()) > 0

# You can reset them
model.reset_rnn_state()

# Expect that the RNN states are zero again
assert np.count_nonzero(model.state_h.numpy()) == 0
assert np.count_nonzero(model.state_c.numpy()) == 0

Hidden and Cell states

While the cell state is important to maintain while processing sequences, the hidden state is the one most often used for making predictions, because it holds the RNN's working representation of the sequence seen so far; the cell state mainly carries long-term memory between steps.
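This distinction can be made concrete with a minimal single-layer LSTM step written in plain NumPy. The gate ordering (input, forget, cell candidate, output) and weight shapes here are illustrative assumptions for a textbook LSTM, not Mathy's implementation: the hidden state `h` is the per-step output a prediction head would consume, while the cell state `c` is only threaded from one step to the next.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h_prev, c_prev, W, U, b):
    """One LSTM step.

    W: (4*units, n_inputs) input weights
    U: (4*units, units)    recurrent weights
    b: (4*units,)          biases
    """
    units = h_prev.shape[0]
    z = W @ x + U @ h_prev + b
    i, f, g, o = (z[:units], z[units:2 * units],
                  z[2 * units:3 * units], z[3 * units:])
    # Cell state: long-term memory, carried between steps
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)
    # Hidden state: the per-step output used for predictions
    h = sigmoid(o) * np.tanh(c)
    return h, c

rng = np.random.default_rng(0)
units, n_inputs = 4, 3
W = rng.normal(size=(4 * units, n_inputs))
U = rng.normal(size=(4 * units, units))
b = np.zeros(4 * units)

# Both states start at zero, as in the example above
h = np.zeros(units)
c = np.zeros(units)
for x in rng.normal(size=(5, n_inputs)):  # a 5-step sequence
    h, c = lstm_step(x, h, c, W, U, b)

print(h.shape, c.shape)  # (4,) (4,)
```

After the loop, `h` is what a downstream layer would read, and `c` exists only so the next call to `lstm_step` can remember what came before, which is why resetting both between episodes matters for a stateful model.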