Swarm Planning Solver¶
This notebook uses the mathy.fragile module for swarm planning to determine which actions to take. The research and implementation come from @Guillemdb and @sergio-hcsoft. They're both amazing 🙇
Training a machine learning model is sometimes inconvenient and time-consuming, especially when working on a new problem type or set of rules.
So, what do we do in these cases? We use Mathy's built-in swarm planning algorithm, of course!
When you're developing a rule or environment, you'd often like to know how well an average agent can be expected to perform on the task without fully training a model each time you change the code.
Let's look together at how we can use mathy.fragile to implement an agent that selects winning actions without any training while still showing its work step-by-step.
!pip install mathy mathy_core mathy_envs
Fractal Monte Carlo¶
The Fractal Monte Carlo (FMC) algorithm we use comes from the mathy.fragile module and uses a swarm of walkers to explore your environment and find optimal paths to the solution. We'll use it with mathy_envs to solve math problems step-by-step.
By the time you're done with this notebook, you should understand how FMC, through its unique path-search capabilities, interfaces with Mathy to tackle Mathy's large, sparse action spaces.
from typing import Any, Dict, Optional, Tuple, Union, cast
import numpy as np
from mathy_core import MathTypeKeysMax
from mathy_envs import EnvRewards, MathyEnv, MathyEnvState
from mathy_envs.gym import MathyGymEnv
from mathy.fragile.env import DiscreteEnv
from mathy.fragile.models import DiscreteModel
from mathy.fragile.states import StatesEnv, StatesModel, StatesWalkers
from mathy.fragile.swarm import Swarm
from mathy.fragile.distributed_env import ParallelEnv
# Use multiprocessing to speed up the swarm
use_mp: bool = True
# The number of walkers to use in the swarm
n_walkers: int = 512
# The number of iterations to run the swarm for
max_iters: int = 100
Action Selection¶
Fragile FMC uses a "Model" class for performing action selection for the walkers in the swarm. Each of the n_walkers walkers needs to select an action at every step, so we sample actions across the whole batch at once.
To aid in navigating the sizeable sparse action space, we'll use the action mask included in mathy observations (by default) to select only valid actions at each swarm step.
class DiscreteMasked(DiscreteModel):
    def sample(
        self,
        batch_size: int,
        model_states: StatesModel,
        env_states: StatesEnv,
        walkers_states: StatesWalkers,
        **kwargs
    ) -> StatesModel:
        if env_states is not None:
            # Each state is a vstack([node_ids, mask]) and we only want the mask.
            masks = env_states.observs[:, -self.n_actions :]
            axis = 1
            # Select a random action using the mask to filter out invalid actions
            random_values = np.expand_dims(
                self.random_state.rand(masks.shape[1 - axis]), axis=axis
            )
            actions = (masks.cumsum(axis=axis) > random_values).argmax(axis=axis)
        else:
            actions = self.random_state.randint(0, self.n_actions, size=batch_size)
        return self.update_states_with_critic(
            actions=actions,
            model_states=model_states,
            batch_size=batch_size,
            **kwargs,
        )
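To see why the cumulative-sum comparison in sample() can only ever pick valid actions, here's a tiny standalone sketch (illustrative only; it assumes each mask row is normalized to probabilities, which mask_as_probabilities=True provides):
# Illustrative sketch of the masked sampling trick above (not part of the solver).
import numpy as np

rng = np.random.RandomState(0)
masks = np.array(
    [
        [0.0, 0.5, 0.5, 0.0],  # only actions 1 and 2 are valid
        [1.0, 0.0, 0.0, 0.0],  # only action 0 is valid
    ]
)
# One uniform sample per row; the first action whose cumulative probability
# exceeds that sample is chosen, so zero-probability actions are never picked.
random_values = np.expand_dims(rng.rand(masks.shape[0]), axis=1)
actions = (masks.cumsum(axis=1) > random_values).argmax(axis=1)
print(actions)  # [2 0] with this seed; always indices of valid actions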
Planning Wrapper¶
Because FMC uses a swarm of many walkers, it's vastly more efficient to interact with them in batches, similar to what we did above with action selection.
To support batched stepping, we'll implement a wrapper environment that exposes the expected plangym-style interface around an internal mathy environment. The class also implements a step_batch method for stepping a batch of environment states at once.
import gymnasium as gym
from gymnasium import spaces
class PlanningEnvironment:
    """Fragile Environment for solving Mathy problems."""

    problem: Optional[str]

    @property
    def unwrapped(self) -> MathyGymEnv:
        return cast(MathyGymEnv, self._env.unwrapped)

    def __init__(
        self,
        name: str,
        environment: str = "poly",
        difficulty: str = "normal",
        problem: Optional[str] = None,
        max_steps: int = 64,
        **kwargs,
    ):
        self._env = gym.make(
            f"mathy-{environment}-{difficulty}-v0",
            invalid_action_response="terminal",
            env_problem=problem,
            mask_as_probabilities=True,
            **kwargs,
        )
        self.observation_space = spaces.Box(
            low=0,
            high=MathTypeKeysMax,
            shape=(256, 256, 1),
            dtype=np.uint8,
        )
        self.action_space = spaces.Discrete(self._env.unwrapped.action_size)
        self.problem = problem
        self.max_steps = max_steps
        self._env.reset()

    def get_state(self) -> np.ndarray:
        assert self.unwrapped.state is not None, "env required to get_state"
        return self.unwrapped.state.to_np(2048)

    def set_state(self, state: np.ndarray):
        assert self.unwrapped is not None, "env required to set_state"
        self.unwrapped.state = MathyEnvState.from_np(state)
        return state

    def step(
        self, action: int, state: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, np.ndarray, Any, bool, Dict[str, object]]:
        assert self._env is not None, "env required to step"
        assert state is not None, "only works with state stepping"
        self.set_state(state)
        obs, reward, _, _, info = self._env.step(action)
        oob = not info.get("valid", False)
        new_state = self.get_state()
        return new_state, obs, reward, oob, info

    def step_batch(
        self,
        actions,
        states: Optional[Any] = None,
        n_repeat_action: Optional[Union[int, np.ndarray]] = None,
    ) -> tuple:
        data = [self.step(action, state) for action, state in zip(actions, states)]
        new_states, observs, rewards, terminals, infos = [], [], [], [], []
        for d in data:
            new_state, obs, _reward, end, info = d
            new_states.append(new_state)
            observs.append(obs)
            rewards.append(_reward)
            terminals.append(end)
            infos.append(info)
        return new_states, observs, rewards, terminals, infos

    def reset(self, batch_size: int = 1):
        assert self._env is not None, "env required to reset"
        obs, info = self._env.reset()
        return self.get_state(), obs
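Before wiring the wrapper into Fragile, here's a quick smoke-test sketch (illustrative only; the action indices are arbitrary and may well be invalid, which simply ends that episode because invalid_action_response="terminal"):
# Illustrative only: the wrapper is effectively stateless between calls because
# step() always receives the state it should operate on.
env = PlanningEnvironment(name="mathy_v0")
state, obs = env.reset()
new_state, obs, reward, oob, info = env.step(action=0, state=state)
new_states, observs, rewards, oobs, infos = env.step_batch(
    actions=[0, 1], states=[state, state]
)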
FMC Environment¶
To use the batch planning environment, we need a Mathy environment that extends the discrete environment exposed by Fragile.
There's not much special here: we instantiate the planning environment for use by the base class and implement the make_transitions
function to set terminal states according to the mathy_envs "done" info.
class FMCEnvironment(DiscreteEnv):
    """Fragile FMC Environment for solving Mathy problems."""

    def __init__(
        self,
        name: str,
        environment: str = "poly",
        difficulty: str = "easy",
        problem: Optional[str] = None,
        max_steps: int = 64,
        **kwargs,
    ):
        self._env = PlanningEnvironment(
            name=name,
            environment=environment,
            difficulty=difficulty,
            problem=problem,
            max_steps=max_steps,
            **kwargs,
        )
        self._n_actions = self._env.action_space.n
        # Skip DiscreteEnv.__init__ and initialize the base Environment directly,
        # since we manage our own wrapped environment here.
        super(DiscreteEnv, self).__init__(
            states_shape=self._env.get_state().shape,
            observs_shape=self._env.observation_space.shape,
        )

    def make_transitions(
        self, states: np.ndarray, actions: np.ndarray, dt: Union[np.ndarray, int]
    ) -> Dict[str, np.ndarray]:
        new_states, observs, rewards, oobs, infos = self._env.step_batch(
            actions=actions, states=states
        )
        terminals = [inf.get("done", False) for inf in infos]
        data = {
            "states": np.array(new_states),
            "observs": np.array(observs),
            "rewards": np.array(rewards),
            "oobs": np.array(oobs),
            "terminals": np.array(terminals),
        }
        return data
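As a quick illustrative check (the problem string, actions, and dt value below are arbitrary placeholders, not part of the swarm run itself), you can call make_transitions directly on a small batch of copies of the initial state:
# Illustrative only: step two copies of the initial state with two different actions.
fmc_env = FMCEnvironment(name="mathy_v0", problem="4x + 2x", repeat_problem=True)
start = fmc_env._env.get_state()
batch = fmc_env.make_transitions(
    states=np.stack([start, start]), actions=np.array([0, 1]), dt=1
)
print(batch["rewards"], batch["oobs"], batch["terminals"])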
Swarm Solver¶
Now that we've set up a masked action selector and a batch-capable environment for planning with many walkers, we can put it all together and use the power of the Fractal Monte Carlo swarm to find a path to our desired solution.
def swarm_solve(problem: str, max_steps: int = 256, silent: bool = False) -> None:
    def mathy_dist(x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """Calculate Euclidean distance between two arrays."""
        return np.linalg.norm(x - y, axis=1)

    def env_callable():
        """Environment setup for solving the given problem."""
        return FMCEnvironment(
            name="mathy_v0",
            problem=problem,
            repeat_problem=True,
            max_steps=max_steps,
        )

    mathy_env: MathyEnv = env_callable()._env.unwrapped.mathy
    if use_mp:
        env_callable = ParallelEnv(env_callable=env_callable)
    swarm = Swarm(
        model=lambda env: DiscreteMasked(env=env),
        env=env_callable,
        reward_limit=EnvRewards.WIN,
        n_walkers=n_walkers,
        max_epochs=max_iters,
        reward_scale=1,
        distance_scale=3,
        distance_function=mathy_dist,
        show_pbar=False,
    )
    if not silent:
        print(f"Solving {problem} ...\n")
    swarm.run()
    if not silent:
        if swarm.walkers.best_reward > EnvRewards.WIN:
            last_state = MathyEnvState.from_np(swarm.walkers.states.best_state)
            mathy_env.print_history(last_state)
            print(f"Solved! {problem} = {last_state.agent.problem}")
        else:
            print("Failed to find a solution.")
        print(f"\nBest reward: {swarm.walkers.best_reward}\n\n")
Evaluation¶
So, after all that work, we can finally test and see how well the swarm can solve the problems we input. Let's give it a go!
It's essential to remember that our chosen environment only has a specific set of rules, so problems that require other rules to solve will not work here.
Let's recall which rules are available in the environment and solve a few problems:
env = FMCEnvironment(name="mathy_v0")
rules = "\n\t".join([e.name for e in env._env.unwrapped.mathy.rules])
print(f"Environment rules:\n\t{rules}\n")
swarm_solve("2x * x + 3j^7 + (1.9x^2 + -8y)")
swarm_solve("4x + 2y + 3j^7 + 1.9x + -8y")
Environment rules:
    Constant Arithmetic
    Commutative Swap
    Distributive Multiply
    Distributive Factoring
    Associative Group
    Variable Multiplication
    Restate Subtraction
Solving 2x * x + 3j^7 + (1.9x^2 + -8y) ...
initial-state(-1) | 2x * x + 3j^7 + (1.9x^2 + -8y)
variable multiplication(3) | 2x^(1 + 1) + 3j^7 + (1.9x^2 + -8y)
associative group(19) | 2x^(1 + 1) + 3j^7 + 1.9x^2 + -8y
constant arithmetic(5) | 2x^2 + 3j^7 + 1.9x^2 + -8y
commutative swap(5) | 3j^7 + 2x^2 + 1.9x^2 + -8y
distributive factoring(11) | 3j^7 + (2 + 1.9) * x^2 + -8y
constant arithmetic(7) | 3j^7 + 3.9x^2 + -8y
Solved! 2x * x + 3j^7 + (1.9x^2 + -8y) = 3j^7 + 3.9x^2 + -8y
Best reward: 1.34333336353302
Solving 4x + 2y + 3j^7 + 1.9x + -8y ...
initial-state(-1) | 4x + 2y + 3j^7 + 1.9x + -8y
commutative swap(13) | 4x + 2y + 1.9x + 3j^7 + -8y
commutative swap(7) | 4x + 1.9x + 2y + 3j^7 + -8y
associative group(7) | 4x + 1.9x + (2y + 3j^7) + -8y
associative group(3) | 4x + (1.9x + (2y + 3j^7)) + -8y
distributive factoring(3) | (4 + 1.9) * x + (2y + 3j^7) + -8y
commutative swap(15) | (4 + 1.9) * x + -8y + (2y + 3j^7)
constant arithmetic(1) | 5.9x + -8y + (2y + 3j^7)
distributive factoring(7) | 5.9x + (-8 + 2) * y + 3j^7
constant arithmetic(5) | 5.9x + -6y + 3j^7
Solved! 4x + 2y + 3j^7 + 1.9x + -8y = 5.9x + -6y + 3j^7
Best reward: 1.202222228050232
Conclusion¶
If you're reading this, you either skipped ahead or you're an absolute legend! Either way, congrats!
I hope you now have a better understanding of how planning algorithms can integrate with Mathy to solve complex environments without a trained model.