
Add support for memoroids (linear recurrent models)#91

Open
smorad wants to merge 41 commits into EdanToledo:main from smorad:memoroid

Conversation


smorad commented Jun 16, 2024

What?

Implement FFM in flax

Why?

For #54

How?

Simply adds a new model

Extra

More commits coming. Just opening this to keep you in the loop.

@smorad smorad changed the title Add support for memoroids (linear recurrent modesl) Add support for memoroids (linear recurrent models) Jun 16, 2024

smorad commented Jun 16, 2024

Maybe you can help guide me on how to integrate this. I was thinking of copying the recurrent PPO script and replacing the LSTM with this. I think the biggest issue is that I would need to do some plumbing to get the start flag used in memoroids. Basically, start should be 1 at the initial timestep of an episode and 0 otherwise.
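To make the start-flag semantics concrete, here is a hedged sketch (a toy, not the FFM implementation in this PR) of how such a start flag can reset state inside an associative scan. It uses a simple linear recurrence h_t = a_t * h_{t-1} + x_t; the names `combine`, `decay`, and `start` are purely illustrative:

```python
import jax
import jax.numpy as jnp

def combine(left, right):
    # Scan elements are (decay, value, start). If the right element begins a
    # new episode (start == True), it discards everything accumulated so far.
    a_l, b_l, s_l = left
    a_r, b_r, s_r = right
    a = jnp.where(s_r, a_r, a_l * a_r)
    b = jnp.where(s_r, b_r, a_r * b_l + b_r)
    return a, b, jnp.logical_or(s_l, s_r)

decay = jnp.full((6,), 0.9)
x = jnp.ones((6,))
start = jnp.array([True, False, False, True, False, False])  # episode boundaries

_, h, _ = jax.lax.associative_scan(combine, (decay, x, start))
# h restarts from x at t=0 and t=3, so h[3] equals h[0]
```

The operator stays associative because a start on the right-hand element absorbs everything to its left, which is exactly the behaviour a start=1 initial timestep needs.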


EdanToledo commented Jun 17, 2024

Amazing! This should be quite simple. The timestep.last() function checks if the current time step is the last one. In the current auto-reset API, when an episode finishes and a new one begins, it automatically resets and returns the first observation of the new episode. This means that timestep.last() will indicate the beginning of a new observation (i.e., the "start"), but the reward it returns is from the last time step of the previous episode. You can observe this behavior in the scanned RNN network, where we use this function to reset the hidden state. We want the hidden state to be zeros for the first observation of each new episode. Let me know if that makes sense.
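For illustration, the reset pattern described above can be sketched like this (assumed shapes; this is not Stoix's actual scanned RNN code). Because of auto-reset, done == True marks the row whose observation already belongs to the new episode, so the carry is replaced with its initial value there:

```python
import jax.numpy as jnp

def select_carry(h, done, h_init):
    # h: [batch, hidden]; done: [batch]; h_init: [batch, hidden]
    # Rows where an episode just restarted get a fresh (zero) hidden state.
    return jnp.where(done[:, None], h_init, h)

h = jnp.ones((2, 4))
h_init = jnp.zeros((2, 4))
done = jnp.array([True, False])
h = select_carry(h, done, h_init)
# row 0 is reset to zeros; row 1 keeps its previous carry
```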


smorad commented Jun 17, 2024

Sorry, I'm less familiar with the dm_env step format. I know with gym, the done flag denotes that the following observation is the terminal observation. With your format, is the done flag (discount == 0) set at the same time as timestep.last() == True? Obviously it's usually either one or the other that's set, but for the purposes of off-by-one errors are they equivalent?

@EdanToledo

No problem. Let me show how the auto-reset API works with an example to give you a better idea. Imagine a 3-timestep environment where each observation (obs) is simply the timestep number. Here is the rollout trajectory you would get, assuming you are storing transitions (obs, act, rew, done):

Trajectory Example

Episode 1:

t=0:

  • obs = 0
  • act = any (action taken using obs = 0 as input)
  • rew = reward obtained from taking act using obs = 0 as input
  • done = False

t=1:

  • obs = 1
  • act = any (action taken using obs = 1 as input)
  • rew = reward obtained from taking act using obs = 1 as input
  • done = False

t=2 (Last Timestep, terminal as well as start of episode 2):

  • obs = 0 (we auto-reset to the first observation of the next episode)
  • act = any (this would now be the action taken using obs = 0 as input)
  • rew = this would now be the reward obtained from taking act using obs = 0 as input
  • done = True (indicated by timestep.last or discount == 0)

As you can see, in this final transition the obs, act, and rew actually belong to the first timestep of the second episode. So when utilising this as a sequence, we use the done/discount to mask the bootstrap prediction to zero, and we don't use the action and reward from this timestep.

Continuing on we would get:

Episode 2 remainder:

t=3:

  • obs = 1
  • act = any (action taken using obs = 1 as input)
  • rew = reward obtained from taking act using obs = 1 as input
  • done = False

t=4 (Last Timestep, terminal):

  • obs = 0 (auto-reset to the first observation of the next episode)
  • act = any (action taken using obs = 0 as input)
  • rew = reward obtained from taking act using obs = 0 as input
  • done = True (indicated by timestep.last or discount == 0)

Explanation

In this scenario, we never actually see obs = 2, which is the terminal observation. For non-truncating environments, the done flags (timestep.last or discount == 0 in this case) and discounts allow us to mask the bootstrap value when bootstrapping from the second obs = 0 prediction. This ensures it doesn't matter that it's not the true final observation. This does involve some care when utilising these sequences to construct value targets, but it's not that complicated - usually you just chop off the last timestep's reward and action values. You can see examples of this in any off-policy algorithm that uses sequences (see the MPO target value construction).
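As a hedged illustration of that masking (the alignment below is an assumption for the sketch: the auto-reset row stores discount 0, and its reward/action are dropped), one-step value targets over such a sequence might be built like this:

```python
import jax.numpy as jnp

gamma = 0.99
# One 3-row sequence as in the trajectory above: the last row is the
# auto-reset row, so its reward belongs to the next episode and is dropped.
rewards   = jnp.array([1.0, 1.0, 5.0])
discounts = jnp.array([1.0, 1.0, 0.0])  # 0.0 masks the bootstrap at the reset row
values    = jnp.array([2.0, 2.0, 2.0])  # V(obs_t) predictions over the sequence

# Target for step t: r_t + gamma * d_{t+1} * V(obs_{t+1}).
# The zero discount on the reset row kills the bootstrap from the new
# episode's first observation; the final reward is simply chopped off.
targets = rewards[:-1] + gamma * discounts[1:] * values[1:]
# targets[1] == 1.0: no bootstrap across the episode boundary
```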

If you need access to the true final observation, it is available in the extras object. This does require considering whether or not your environment truncates. For 1-step transitions like DQN, I simply save o_t and o_t+1, which eliminates any possible issue since I can use the true observations. The reason for auto-resetting like this is that it removes any dummy rewards and discounts: for example, when going from the terminal obs to the starting obs, there wouldn't be a real reward, action, or discount.

Let me know if this helps.

Lastly, to explicitly answer your question: when timestep.last() == True, that timestep logically corresponds to the terminal observation; however, as mentioned above, the true terminal observation is not what is actually returned (auto-reset replaces it with the first observation of the new episode).


EdanToledo commented Jun 23, 2024

I'm just leaving a checklist here of things that need to be done:

  • Explicit catering for the batch dimension, i.e. not relying on flax.nn.vmap
  • When feeding in a starting carry, I would like it to not need a sequence dimension, i.e. just a batch dimension and a feature dimension. It feels more natural in my head to feed it like this - if we decide otherwise, we need to change the other RNN classes to expect it this way.
  • Add one more type of cell (but only one more, since I think we leave others for different PRs) just to check how general the infrastructure is.
  • Test on popgym to ensure it works

@EdanToledo

Additionally, I've officially merged the popgym PR, so now we can test on popgym envs easily when we feel ready.


smorad commented Jun 24, 2024

With respect to removing the sequence dimension in initialize_carry: I think you would be unable to run the cell without a singleton sequence dimension. I can implement a squeeze/unsqueeze in the FFM module, but again, this would mean that you could not run the cell on its own, as follows:

```python
h = cell.initial_state()
x = jnp.ones(...)
h, y = cell(h, x)
```

If you run the following snippet, you will see how the time dimension is present in FFMCell.__call__. Unlike a standard scan, the scanned function operates over more than one element at a time.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1024, 4))  # [Time, feature]
W = jnp.ones((4, 4))

def ascanf(x, xp):
    print(x.shape, xp.shape)
    return xp @ W

jax.lax.associative_scan(fn=ascanf, elems=x)
```

which prints:

```
(512, 4) (512, 4)
(256, 4) (256, 4)
(128, 4) (128, 4)
(64, 4) (64, 4)
(32, 4) (32, 4)
(16, 4) (16, 4)
(8, 4) (8, 4)
(4, 4) (4, 4)
(2, 4) (2, 4)
(1, 4) (1, 4)
(0, 4) (0, 4)
(1, 4) (1, 4)
(3, 4) (3, 4)
(7, 4) (7, 4)
(15, 4) (15, 4)
(31, 4) (31, 4)
(63, 4) (63, 4)
(127, 4) (127, 4)
(255, 4) (255, 4)
(511, 4) (511, 4)
```

Let me know how you want to proceed


EdanToledo commented Jun 24, 2024

Hmm, I see. Could we implement the squeeze/unsqueeze logic only in the outermost architecture, thus still allowing the cell to be run on its own? So basically something like:

```python
from typing import Optional

import chex
import jax
import jax.numpy as jnp
import flax.linen as nn

# Carry and recurrent_associative_scan come from the memoroid module in this PR.

class ScannedMemoroid(nn.Module):
    cell: nn.Module

    @nn.compact
    def __call__(self, recurrent_state, inputs):
        ### CHANGE HERE: add the singleton sequence dimension back
        recurrent_state = jax.tree.map(lambda x: jnp.expand_dims(x, 0), recurrent_state)
        # Recurrent state should be ((state, timestep), reset)
        # Inputs should be (x, reset)
        x, _ = inputs
        h = self.cell.map_to_h(inputs)
        recurrent_state = recurrent_associative_scan(self.cell, recurrent_state, h)
        # recurrent_state is ((state, timestep), reset)
        out = self.cell.map_from_h(recurrent_state, x)

        # TODO: Remove this when we want to return all recurrent states instead of just the last one
        final_recurrent_state = jax.tree.map(lambda x: x[-1:], recurrent_state)
        return final_recurrent_state, out

    @nn.nowrap
    def initialize_carry(
        self, batch_size: Optional[int] = None, rng: Optional[chex.PRNGKey] = None
    ) -> Carry:
        ### AND CHANGE HERE: drop the sequence dimension from the cell's carry
        return jax.tree.map(lambda x: x.squeeze(0), self.cell.initialize_carry(batch_size, rng))
```

Let me know what you think.

EDIT/UPDATE:

I just made this change, and it allows the system file to be essentially identical to the rec_ppo system which uses normal RNNs. I think this is ideal, as it means that for all future recurrent algorithms we don't need to differentiate between memoroids and normal RNNs. The only difference currently is the use of nn.vmap; once we code it to explicitly handle batch dimensions, we can use the rec_ppo file exactly and change the network via the config. So the most pressing change would be to do that.

@EdanToledo

@smorad I've now added the explicit expectation of a batch dimension - the network now works with rec_ppo.py natively, simply by changing the network config. For example, we can now do: `python stoix/systems/ppo/rec_ppo.py network=memoroid`

We now need to verify correctness. I do worry that there might be a bug somewhere, since the performance on CartPole isn't that good: we get to 200+ quite quickly, but it struggles to get to 500. Lastly, I had one concern: I see that the start variable is part of the carry - is this normal? Ideally we feed the start sequence in via the inputs, not the carry. I'm not sure if it should be there - let me know?
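One cheap probe for this kind of bug (a hedged sketch; `toy_apply` is a stand-in recurrence, not the FFM cell in this PR) is a chunked-equivalence check: a single full-sequence call must produce the same outputs as feeding the sequence one timestep at a time while threading the carry through.

```python
import jax.numpy as jnp

def toy_apply(carry, inputs):
    # Stand-in recurrent model: h_t = 0.5 * h_{t-1} + x_t, reset at starts.
    xs, resets = inputs
    h = carry
    outs = []
    for t in range(xs.shape[0]):
        h = jnp.where(resets[t], 0.0, h)
        h = 0.5 * h + xs[t]
        outs.append(h)
    return h, jnp.stack(outs)

def check_chunked_equivalence(apply_fn, carry, xs, resets):
    # Full sequence in one call...
    _, full_out = apply_fn(carry, (xs, resets))
    # ...versus one timestep at a time, threading the carry through.
    h, outs = carry, []
    for t in range(xs.shape[0]):
        h, y = apply_fn(h, (xs[t:t + 1], resets[t:t + 1]))
        outs.append(y)
    step_out = jnp.concatenate(outs, axis=0)
    return bool(jnp.allclose(full_out, step_out, atol=1e-5))

xs = jnp.arange(1.0, 7.0)
resets = jnp.array([True, False, False, True, False, False])
ok = check_chunked_equivalence(toy_apply, jnp.zeros(()), xs, resets)
# ok should be True for a correct implementation
```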


adzcai commented Dec 19, 2025

Hey, I was looking into efficient recurrent models recently, and this line of exploration seems very interesting. Wondering if there are any updates on this? Happy to help during this winter!


smorad commented Dec 19, 2025

I eventually put all my recurrent models together into the memax library. We wrote a flax backend for memax together, with the goal of using it in stoix, but unfortunately both of us got really busy before we had a chance to combine them. memax provides a single unified API for all recurrent models, so I do not think it would be too difficult to plug it in. If you'd like to give it a shot, I can provide some guidance.


adzcai commented Dec 19, 2025

That sounds awesome! I'm hoping to use this for a recurrent MuZero implementation and am super excited to hear about memax. I'd love to hear your thoughts before diving in.


smorad commented Dec 19, 2025

I think it is generally useful for MBRL, as the associative recurrent models are significantly faster than something like a GRU. My experience with MBRL is that it requires much more training than model-free RL, so having efficient recurrent models is important. It is also very easy to make subtle mistakes when implementing recurrent models, reducing performance by a significant amount, but not so significantly that you know something is wrong. For this reason, it is useful to leverage existing/tested code where you can.

You might want to start by looking at this unit test that shows how to instantiate and call all flax recurrent models.

As for integrating it into stoix:

  • Just pass the observation and done flag to the memax modules as a tuple; it handles recurrent state resets automatically
  • You should not pass the output from zero_carry to the model -- you just use it to determine shapes for initial_carry. Multiplicative RNNs need to have identity (not zero) initial states.
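As a hedged toy illustration of the second point (not memax code): for a purely multiplicative recurrence h_t = a_t * h_{t-1}, a zero initial carry collapses the whole state, while an identity (ones) carry gives the correct prefix products.

```python
import jax.numpy as jnp

a = jnp.array([0.9, 0.8, 0.7])  # per-step multiplicative updates

def rollout(h0):
    h = h0
    hs = []
    for t in range(a.shape[0]):
        h = a[t] * h  # multiplicative state update
        hs.append(h)
    return jnp.stack(hs)

zero_init = rollout(jnp.array(0.0))      # every state is zero: the carry killed it
identity_init = rollout(jnp.array(1.0))  # correct prefix products 0.9, 0.72, 0.504
```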

P.S. I implemented ~15 models in equinox because that is the framework I like best. The flax backend has fewer implemented. But porting them to flax is really easy because you only need to change 5 lines (mostly in the constructor). It's something an LLM can probably one-shot, but there's been no request for this so I haven't done it yet. The API is the same except for some minor differences with how flax constructors deal with state.
