Policy/Value Model

Overview

Mathy uses a model that predicts which action to take in an environment and estimates the scalar value of the current state.

Model

Mathy's policy/value model takes in a window of observations and outputs a weighted distribution over all possible actions, along with a value estimate for each observation.

Examples

Call the Model

The simplest thing to do is to load a blank model and pass some data through it. This gives us a sense of how things work:

Open Example In Colab

import tensorflow as tf

from mathy import envs
from mathy.agents.base_config import BaseConfig
from mathy.agents.policy_value_model import PolicyValueModel
from mathy.env import MathyEnv
from mathy.state import MathyObservation, observations_to_window

args = BaseConfig()
env: MathyEnv = envs.PolySimplify()
# Build an observation from the environment's initial state
observation: MathyObservation = env.state_to_observation(env.get_initial_state()[0])
# A blank model sized to the environment's action space
model = PolicyValueModel(args, predictions=env.action_size)
inputs = observations_to_window([observation]).to_inputs()
# predict_next only returns a policy for the last observation
# in the sequence, and applies masking and softmax to the output
policy, value = model.predict_next(inputs)

# The policy is a 1D array of size (actions * num_nodes)
assert policy.shape.rank == 1
assert policy.shape == (env.action_size * len(observation.nodes),)

# The value output should be a single floating point number
assert value.shape.rank == 0
assert isinstance(float(value.numpy()), float)
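
Once you have the flat policy, you can pick an action from it. The sketch below chooses the highest-probability entry and decodes it into a rule/node pair; the rule-major flattening order (index == rule_index * num_nodes + node_index) is an assumption here, so verify it against your version of Mathy.

# Sketch: greedy action selection from the flat policy. The rule-major
# layout assumed by divmod below is an assumption, not something this
# example guarantees.
flat_index = int(tf.argmax(policy))
rule_index, node_index = divmod(flat_index, len(observation.nodes))
print(f"apply rule {rule_index} at node {node_index}")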

Save Model with Optimizer

Mathy's optimizer is stateful, so it has to be saved alongside the model if you want to pause and resume training later. To help with this, Mathy provides the get_or_create_policy_model function.

The helper function handles:

  • Creating a folder if needed to store the model and related files
  • Saving the agent hyperparameters used for training the model to model.config.json
  • Initializing and sanity checking the model by compiling and calling it with a random observation
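
Before the full training example below, here is a minimal sketch of calling the helper directly to create a fresh model. It reuses A3CConfig only because the example below shows that config accepts model_dir; the small unit sizes are arbitrary.

import tempfile

from mathy.agents.a3c import A3CConfig
from mathy.agents.policy_value_model import get_or_create_policy_model
from mathy.envs import PolySimplify

# Sketch: create (or load) a model in a throwaway folder. The helper
# writes model.config.json there and sanity checks the model by
# compiling and calling it once.
config = A3CConfig(model_dir=tempfile.mkdtemp(), units=4, embedding_units=4, lstm_units=4)
model = get_or_create_policy_model(
    args=config, predictions=PolySimplify().action_size, is_main=True
)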

Open Example In Colab

#!pip install gym
import shutil
import tempfile

from mathy.agents.a3c import A3CAgent, A3CConfig
from mathy.agents.policy_value_model import PolicyValueModel, get_or_create_policy_model
from mathy.cli import setup_tf_env
from mathy.envs import PolySimplify

model_folder = tempfile.mkdtemp()
setup_tf_env()

args = A3CConfig(
    max_eps=3,
    verbose=True,
    topics=["poly"],
    model_dir=model_folder,
    update_gradients_every=4,
    num_workers=1,
    units=4,
    embedding_units=4,
    lstm_units=4,
    print_training=True,
)
# Train for a few short episodes so a model is written to model_folder
instance = A3CAgent(args)
instance.train()
# Load the model back in
model_two = get_or_create_policy_model(
    args=args, predictions=PolySimplify().action_size, is_main=True
)
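
# Optionally sanity check the reloaded model by passing a single
# observation through it, using the same calls as the "Call the Model"
# example above, before the temporary folder is removed below.
from mathy.state import observations_to_window

check_env = PolySimplify()
check_observation = check_env.state_to_observation(check_env.get_initial_state()[0])
policy, value = model_two.predict_next(
    observations_to_window([check_observation]).to_inputs()
)
assert policy.shape == (check_env.action_size * len(check_observation.nodes),)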
# Comment this out to keep your model
shutil.rmtree(model_folder)