Plug-n-Play Reinforcement Learning in Python with OpenAI Gym and JAX

Overview

coax

For the full documentation, including many examples, go to https://coax.readthedocs.io/

Install

coax is built on top of JAX, but it doesn't declare the jax Python package as an explicit dependency. The reason is that the jaxlib build you need depends on your CUDA version. To install without CUDA support, run:

$ pip install jaxlib jax coax --upgrade

If you do require CUDA support, please check out the Installation Guide.
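
Because jax and jaxlib are installed separately from coax, it can help to confirm which versions actually ended up in your environment. The following is a minimal sketch using only the standard library; the helper name `installed_versions` is hypothetical and not part of coax.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "coax")):
    """Return {package: version or None} without importing the packages."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

A `None` entry for jaxlib after installation would suggest that the pip command ran against a different environment than the one you are checking.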

Getting Started

Have a look at the Getting Started page to train your first RL agent.


Comments
  • Quantile Q-Learning Implementation

    This PR adds a QuantileQ class with function types 3 and 4 that accept a number of quantiles together with the state (and action), as well as a QuantileQLearning class. The QuantileQ function could be merged into the Q class which would simplify the user-facing API. However, some more work needs to be done to incorporate the QuantileQLearning class into the QLearning class. I just wanted to validate that this is the correct approach to take to implement the IQN.

    Some documentation for the quantile Huber loss is still missing, and the notebooks still need to be added and tuned.
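
For readers unfamiliar with the quantile Huber loss mentioned above, here is a minimal scalar sketch in the form commonly used for IQN, where the Huber loss is weighted by |tau - 1{u < 0}| and divided by kappa. This is an illustration under that assumption, not the code from this PR; the function names and the `kappa` default are hypothetical.

```python
def huber(u, kappa=1.0):
    """Standard Huber loss: quadratic near zero, linear beyond kappa."""
    a = abs(u)
    if a <= kappa:
        return 0.5 * u * u
    return kappa * (a - 0.5 * kappa)


def quantile_huber(u, tau, kappa=1.0):
    """Quantile Huber loss for a TD error u at quantile fraction tau.

    Positive errors (under-estimates) are weighted by tau,
    negative errors (over-estimates) by 1 - tau.
    """
    indicator = 1.0 if u < 0 else 0.0
    return abs(tau - indicator) * huber(u, kappa) / kappa


if __name__ == "__main__":
    for u in (-0.5, 0.5, 2.0):
        print(u, quantile_huber(u, tau=0.9))
```

With tau = 0.9, an error of +0.5 is penalized nine times more than an error of -0.5, which is what pushes the estimate toward the 0.9 quantile.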

    Closes https://github.com/coax-dev/coax/issues/3

    opened by frederikschubert 11
  • Add DeepMind Control Suite Example

    This PR is a rework of https://github.com/coax-dev/coax/pull/26 and adds an example for using SAC on the Walker.walk task from the DeepMind Control Suite.

    Depends on https://github.com/coax-dev/coax/pull/27 and https://github.com/coax-dev/coax/pull/28

    opened by frederikschubert 6
  • Assertion assert_equal_shape failed for MultiDiscrete action space

    First of all, thank you for developing this package; I really like the modular design. I am a bit new to RL and the JAX ecosystem, so my question may be a bit naive. I am currently doing a baseline study with my customized gym environment and VanillaPG, but I encountered the bug shown below and could not figure it out. My understanding is that it is complaining that the shape of log_pi should not be (4,). But I do have a MultiDiscrete action space, and its corresponding log_pi should be something like (4,) or (1, 4). I have also attached the output of coax.Policy.example_data(env) and my policy function definition below, in case that helps explain the situation.

    So my questions are:

    1. Do you think this error is related to the fact that I have a MultiDiscrete action space?
    2. Did I declare my policy function properly?
    3. Any general ideas on how to debug JAX functions?

    I would appreciate any feedback. Thank you!

    Error message

    ---------------------------------------------------------------------------
    AssertionError                            Traceback (most recent call last)
    Input In [25], in <cell line: 5>()
         13     transition_batch = tracer.pop()
         14     Gn = transition_batch.Rn
    ---> 15     metrics = vanilla_pg.update(transition_batch, Adv=Gn)
         16     env.record_metrics(metrics)
         17 if done:
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:149, in PolicyObjective.update(self, transition_batch, Adv)
        127 def update(self, transition_batch, Adv):
        128     r"""
        129 
        130     Update the model parameters (weights) of the underlying function approximator.
       (...)
        147 
        148     """
    --> 149     grads, function_state, metrics = self.grads_and_metrics(transition_batch, Adv)
        150     if any(jnp.any(jnp.isnan(g)) for g in jax.tree_leaves(grads)):
        151         raise RuntimeError(f"found nan's in grads: {grads}")
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:218, in PolicyObjective.grads_and_metrics(self, transition_batch, Adv)
        212 if self.REQUIRES_PROPENSITIES and jnp.all(transition_batch.logP == 0):
        213     warnings.warn(
        214         f"In order for {self.__class__.__name__} to work properly, transition_batch.logP "
        215         "should be non-zero. Please sample actions with their propensities: "
        216         "a, logp = pi(s, return_logp=True) and then add logp to your reward tracer, "
        217         "e.g. nstep_tracer.add(s, a, r, done, logp)")
    --> 218 return self._grad_and_metrics_func(
        219     self._pi.params, self._pi.function_state, self.hyperparams, self._pi.rng,
        220     transition_batch, Adv)
    
    File ~/opt/python3.9/site-packages/coax/utils/_jit.py:59, in JittedFunc.__call__(self, *args, **kwargs)
         58 def __call__(self, *args, **kwargs):
    ---> 59     return self._jitted_func(*args, **kwargs)
    
        [... skipping hidden 14 frame]
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:80, in PolicyObjective.__init__.<locals>.grads_and_metrics_func(params, state, hyperparams, rng, transition_batch, Adv)
         77 def grads_and_metrics_func(params, state, hyperparams, rng, transition_batch, Adv):
         78     grads_func = jax.grad(loss_func, has_aux=True)
         79     grads, (metrics, state_new) = \
    ---> 80         grads_func(params, state, hyperparams, rng, transition_batch, Adv)
         82     # add some diagnostics of the gradients
         83     metrics.update(get_grads_diagnostics(grads, f'{self.__class__.__name__}/grads_'))
    
        [... skipping hidden 10 frame]
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:47, in PolicyObjective.__init__.<locals>.loss_func(params, state, hyperparams, rng, transition_batch, Adv)
         45 def loss_func(params, state, hyperparams, rng, transition_batch, Adv):
         46     objective, (dist_params, log_pi, state_new) = \
    ---> 47         self.objective_func(params, state, hyperparams, rng, transition_batch, Adv)
         49     # flip sign to turn objective into loss
         50     loss = -objective
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_vanilla_pg.py:52, in VanillaPG.objective_func(self, params, state, hyperparams, rng, transition_batch, Adv)
         49 W = jnp.clip(transition_batch.W, 0.1, 10.)
         51 # some consistency checks
    ---> 52 chex.assert_equal_shape([W, Adv, log_pi])
         53 chex.assert_rank([W, Adv, log_pi], 1)
         54 objective = W * Adv * log_pi
    
    File ~/opt/python3.9/site-packages/chex/_src/asserts_internal.py:197, in chex_assertion.<locals>._chex_assert_fn(*args, **kwargs)
        195 else:
        196   try:
    --> 197     host_assertion(*args, **kwargs)
        198   except jax.errors.ConcretizationTypeError as exc:
        199     msg = ("Chex assertion detected `ConcretizationTypeError`: it is very "
        200            "likely that it tried to access tensors' values during tracing. "
        201            "Make sure that you defined a jittable version of this Chex "
        202            "assertion.")
    
    File ~/opt/python3.9/site-packages/chex/_src/asserts_internal.py:157, in make_static_assertion.<locals>._static_assert(custom_message, custom_message_format_vars, include_default_message, exception_type, *args, **kwargs)
        154     custom_message = custom_message.format(*custom_message_format_vars)
        155   error_msg = f"{error_msg} [{custom_message}]"
    --> 157 raise exception_type(error_msg)
    
    AssertionError: [Chex] Assertion assert_equal_shape failed: Arrays have different shapes: [(1,), (1,), (4,)].
    

    Example data

    ExampleData(
      inputs=Inputs(
        args=ArgsType2(
          S={
            'features': array(shape=(1, 1000), dtype=float32, min=0.008, median=2.13, max=2.77)
          is_training=True)
        static_argnums=(
          1))
      output=(
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-2.31, median=0.152, max=0.732)},
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-1.54, median=-0.138, max=0.994)},
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-0.984, median=0.0808, max=1.73)},
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-2.74, median=-0.289, max=1.74)}))
    

    Policy function

    def pi(S, is_training):
        module = CustomizedModule()
        res = tuple([{"logits": item} for item in module(S["features"])])
        return res
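
The assertion message reports shapes [(1,), (1,), (4,)], so W and Adv carry one value per transition while log_pi carries one value per action component. As an illustration only (this is not coax's internal code, and it assumes the four components are independent categorical heads), the NumPy sketch below shows how per-head log-probabilities can be summed into a single joint log-probability per transition. The function names are hypothetical.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def joint_log_prob(logits_per_head, actions):
    """Sum per-head log-probs into one value per transition.

    logits_per_head: list of arrays, each of shape (batch, n_choices)
    actions: integer array of shape (batch, n_heads)
    """
    batch = actions.shape[0]
    per_head = [
        log_softmax(logits)[np.arange(batch), actions[:, k]]
        for k, logits in enumerate(logits_per_head)
    ]
    return np.sum(per_head, axis=0)  # shape (batch,)


rng = np.random.default_rng(0)
logits_per_head = [rng.normal(size=(1, 10)) for _ in range(4)]
actions = np.array([[3, 0, 7, 9]])
log_pi = joint_log_prob(logits_per_head, actions)
print(log_pi.shape)  # (1,)
```

Printing the shapes of the intermediate arrays outside of any jitted function, as done here, is also a simple way to debug shape mismatches before they surface inside a traced update.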
    
    question 
    opened by xiangyuy 5
  • 'linear/w' does not match shape

    I've been starting to learn about RL and have been trying to get coax up and running, but have run into an issue that I'm not sure how to resolve. I'm doing Q-learning on a custom gym environment, and I can run the following pieces successfully:

    q = coax.Q(func_q, env)
    pi = coax.Policy(func_pi, env)
    
    qlearning = coax.td_learning.QLearning(q, pi_targ=pi, optimizer=optax.adam(0.001))
    cache = coax.reward_tracing.NStep(n=1, gamma=0.9)
    

    Additionally, my setup passes the simple checks of:

    data = coax.Q.example_data(env) # Looks good
    ...
    s = env.observation_space.sample()
    a = env.action_space.sample()
    print(q(s,a)) # 0.0
    ...
    a = pi(s)
    print(a) # [0, 0, 0, 0, 0] as I have a MultiDiscrete action space
    

    However, once I get to actually running the training loop:

    for ep in range(50):
      pi.epsilon = 0.1
      s = env.reset()
    
      for t in range(env.maxGuesses):
        a = pi(s)
        s_next, r, done, info = env.step(a)
    
        # update
        cache.add(s, a, r, done)
    
        while cache:
          transition_batch = cache.pop()
          metrics = qlearning.update(transition_batch)
          env.record_metrics(metrics)
    
        if done:
          break
    
        s = s_next
    
        # early stopping
        if env.avg_G > env.reward_threshold:
          break
    

    I get a bunch of errors with the most human-readable of them saying:

    ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')
    

    By adjusting the parameters of the environment, I can change which numbers are mismatched, but I can't get them to match. Either way, that seems like the wrong solution, as something more fundamental seems to be the issue.
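
One way to read the error message: the stored weight has shape (420, 30), while the failing call asked for shape [940, 30], so the first linear layer saw an input of width 420 when its parameters were created and an input of width 940 during the update. The toy class below is a stand-in written for illustration (it is not Haiku and does not reproduce Haiku's internals); it only mimics the "remember the shape from the first call, complain on a different shape later" behaviour.

```python
class LazyLinearShape:
    """Toy stand-in (not Haiku): fixes the weight shape on the first call."""

    def __init__(self, output_size):
        self.output_size = output_size
        self.w_shape = None  # unknown until the first input is seen

    def __call__(self, input_size):
        requested = (input_size, self.output_size)
        if self.w_shape is None:
            self.w_shape = requested  # "initialization" call
        elif self.w_shape != requested:
            raise ValueError(
                f"'linear/w' with retrieved shape {self.w_shape} "
                f"does not match shape={list(requested)}"
            )
        return requested


layer = LazyLinearShape(output_size=30)
layer(420)  # first call fixes the weight shape to (420, 30)
try:
    layer(940)  # a later call with a different input width fails
except ValueError as exc:
    print(exc)
```

Under this reading, a useful first check is to print the shapes of the vectors being concatenated in func_q, both for the example data and for a real transition batch, and see where the widths diverge.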

    For reference, here are my functions for q and pi:

    def func_pi(S, is_training):
      logits = hk.Sequential((
        hk.Linear(30), jax.nn.relu, 
        hk.Linear(30), jax.nn.relu, 
        hk.Linear(30), jax.nn.relu,
        hk.Linear(Wordle.wordLength*len(alphabet), w_init=jnp.zeros) # This many possible actions
      ))
      # First, convert to a vector:
      sVec = state_to_vec(S)
    
      # Now get the output:
      logitVec = logits(sVec)
    
      # Now chunk the output into alphabet-sized pieces (definitionally an integral
      # number of them). There will be Wordle.wordLength chunks of this length
      chunks = jnp.split(logitVec, Wordle.wordLength)
    
      # Now format our output array:
      ret = []
      for chunk in chunks:
        ret.append({'logits': jnp.reshape(chunk,(1,len(alphabet)))})
    
      return tuple(ret)
    
    # and for actual state:
    def func_q(S, A, is_training):
      value = hk.Sequential((
        hk.Linear(30), jax.nn.relu, 
        hk.Linear(30), jax.nn.relu,
        hk.Linear(30), jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel
      ))
    
      sVec = state_to_vec(S)
      aVec = action_to_vec(A)
    
      X = jnp.concatenate((sVec, aVec))
      return value(X)
    

    Note that state_to_vec(S) and action_to_vec(A) just convert from my internal types to jnp.array's for use with Haiku.

    I'm quite new to coax/JAX/Haiku so it's entirely possible I've set something up wrong. For completeness here's the full text of the error:

    Traceback (most recent call last):
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 314, in <module>
        metrics = qlearning.update(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 87, in update
        grads, function_state, metrics, td_error = self.grads_and_metrics(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 149, in grads_and_metrics
        return self._grads_and_metrics_func(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 426, in cache_miss
        out_flat = xla.xla_call(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 596, in process_call
        return primitive.impl(f, *tracers, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 142, in _xla_call_impl
        compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 272, in memoized_fun
        ans = call(fun, *args)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 169, in _xla_callable_uncached
        return lower_xla_callable(fun, device, backend, name, donated_invars,
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
        return func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 197, in lower_xla_callable
        jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
        return func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1623, in trace_to_jaxpr_final
        jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1594, in trace_to_subjaxpr_dynamic
        ans = fun.call_wrapped(*in_tracers_)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 462, in grads_and_metrics_func
        grads, (td_error, state_new, metrics) = jax.grad(loss_func, has_aux=True)(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 996, in grad_f_aux
        (_, aux), g = value_and_grad_f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 1067, in value_and_grad_f
        ans, vjp_py, aux = _vjp(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 2478, in _vjp
        out_primal, out_vjp, aux = ad.vjp(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 118, in vjp
        out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 103, in linearize
        jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
        return func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 520, in trace_to_jaxpr
        jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 436, in loss_func
        Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 426, in cache_miss
        out_flat = xla.xla_call(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 324, in process_call
        result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 204, in process_call
        jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 317, in partial_eval
        out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1364, in process_call
        jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1594, in trace_to_subjaxpr_dynamic
        ans = fun.call_wrapped(*in_tracers_)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/transform.py", line 383, in apply_fn
        out = f(*args, **kwargs)
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 264, in func_1
        return value(X)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 125, in __call__
        out = layer(out, *args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 178, in __call__
        w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/base.py", line 319, in get_parameter
        raise ValueError(
    jax._src.traceback_util.UnfilteredStackTrace: ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')
    
    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.
    
    --------------------
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 314, in <module>
        metrics = qlearning.update(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 87, in update
        grads, function_state, metrics, td_error = self.grads_and_metrics(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 149, in grads_and_metrics
        return self._grads_and_metrics_func(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 462, in grads_and_metrics_func
        grads, (td_error, state_new, metrics) = jax.grad(loss_func, has_aux=True)(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 436, in loss_func
        Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/transform.py", line 383, in apply_fn
        out = f(*args, **kwargs)
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 264, in func_1
        return value(X)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 125, in __call__
        out = layer(out, *args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 178, in __call__
        w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/base.py", line 319, in get_parameter
        raise ValueError(
    ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')
    

    Please let me know if other information would be useful or relevant (or let me know if this isn't actually a coax issue...).

    Thanks for your help and the neat package.

    bug good first issue question 
    opened by bcerjan 5
  • DQN pong example doesn't work off the shelf

    Describe the bug

    Running the DQN example on pong generates the following error when generating a gif:

      File ".../lib/python3.9/site-packages/coax/utils/_misc.py", line 475, in generate_gif
        assert env.render_mode == 'rgb_array', "env.render_mode must be 'rgb_array'"
    

    This is likely due to some recent updates to gym. Currently, on gym==0.26.2 I observe the following:

    import gym
    env = gym.make('PongNoFrameskip-v4', render_mode="rgb_array")
    print(env.render_mode) # prints None
    
    opened by thisiscam 4
  • Add dm_control example for SAC

    This PR introduces the common squashed normal distribution for the SAC policy on dm_control and provides an example that solves the walker.walk task. Interestingly, clipping the actions to the range [-1, 1] instead causes training to diverge.
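
For context, a squashed normal policy samples from a Gaussian and passes the sample through tanh, so actions land in (-1, 1) without hard clipping, and the log-probability gets a change-of-variables correction. The one-dimensional sketch below illustrates that standard construction; it is not the implementation from this PR, and the function names and the `eps` stabilizer are assumptions.

```python
import math
import random


def squashed_normal_log_prob(u, mu, log_std, eps=1e-6):
    """Log-density of a = tanh(u), where u ~ Normal(mu, exp(log_std))."""
    std = math.exp(log_std)
    log_prob_u = (
        -0.5 * ((u - mu) / std) ** 2 - log_std - 0.5 * math.log(2 * math.pi)
    )
    # change of variables: da/du = 1 - tanh(u)^2
    return log_prob_u - math.log(1.0 - math.tanh(u) ** 2 + eps)


def sample_squashed_normal(mu, log_std, rng=random):
    """Sample a pre-squash value, squash it, and return (action, log_prob)."""
    u = rng.gauss(mu, math.exp(log_std))
    return math.tanh(u), squashed_normal_log_prob(u, mu, log_std)


if __name__ == "__main__":
    action, log_prob = sample_squashed_normal(0.0, 0.0, random.Random(0))
    print(action, log_prob)
```

Unlike clipping, tanh keeps the mapping from pre-squash sample to action differentiable everywhere, which may be related to the difference in behaviour noted above.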

    @KristianHolsheimer How would you go about changing the installation script for this notebook to add dm_control as a dependency?

    opened by frederikschubert 4
  • Frozen Lake example has an invalid gym signature.

    Describe the bug

    The example for Frozen Lake in the main branch of the docs isn't fully updated for the new version of gym's signature.

    ValueError Traceback (most recent call last) in
         77
         78 a = pi.mode(s)
    ---> 79 s, r, done, info = env.step(a)
         80
         81 env.render()

    ValueError: too many values to unpack (expected 4)
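
Starting with gym 0.26, env.step returns five values (observation, reward, terminated, truncated, info) instead of four, which is what triggers the unpacking error. A small compatibility helper along the following lines would accept both forms; the name `unpack_step` is hypothetical, and this is a sketch rather than the fix that was eventually merged.

```python
def unpack_step(result):
    """Normalize the return value of env.step to (obs, reward, done, info).

    Accepts the older 4-tuple and the 5-tuple introduced in gym 0.26,
    where `done` is split into `terminated` and `truncated`.
    """
    if len(result) == 5:
        obs, reward, terminated, truncated, info = result
        return obs, reward, bool(terminated or truncated), info
    obs, reward, done, info = result
    return obs, reward, bool(done), info


# usage inside a loop would look like:
#     s, r, done, info = unpack_step(env.step(a))
print(unpack_step(("s1", 1.0, False, True, {})))
```

Note that collapsing terminated and truncated into one flag discards the distinction between a true terminal state and a time limit, which matters for bootstrapping in TD methods.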

    Expected behavior

    Executing the notebook should not result in a ValueError.

    To Reproduce

    Colab notebook to repro the bug:

    - https://colab.research.google.com/...

    Runtime used for this colab notebook: ... (e.g. CPU/GPU/TPU)

    Any.

    Additional context

    Simple fix, happy to contribute a pull request.

    opened by dbleyl 3
  • Incorporating jax.jit into a custom policy

    I'm a bit new to JAX, so my question might sound very naive. Suppose we are trying to solve a policy optimization problem with the REINFORCE algorithm, and suppose we already have our environment at hand (env). We define our custom policy as follows:

    class CustomPolicy(hk.Module):
        def __init__(self, name = None):
            super().__init__(name = name)
        
    
        def __call__(self, x):
            w = hk.get_parameter("w", shape= ... , dtype = x.dtype, init=jnp.zeros)
            # some computation
            return out
    

    Per the documentation, then we define

    def custom_policy(S, is_training=True):
        logits = CustomPolicy()
        return {'logits': logits(S)}
    

    and finally the policy is stated as follows,

    pi = coax.Policy(custom_policy, env)

    I was wondering whether there is any way to incorporate @jax.jit into this structure to further speed things up. Thanks.

    question 
    opened by UweGensheimer 3
  • Multi-Step Entropy Regularization for SAC

    • Add record_extra_info flag to the NStep tracer that records the intermediate states in the new extra_info field to TransitionBatch
    • Add support for the NStepEntropyRegularizer in SoftPG

    This PR contains an initial working implementation of the mechanism and sums up the discounted entropy bonuses of the states s_t, s_{t + 1}, ... , s_{t + n - 1} for the soft policy gradient regularization.
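
The summation described above can be sketched as follows. The exact discounting convention used in the PR is not stated here, so this assumes the entropy at step t + k is weighted by gamma**k; the function name and the `alpha` coefficient are hypothetical.

```python
def nstep_entropy_bonus(entropies, gamma, alpha=1.0):
    """Discounted sum of per-state entropies H(s_t), ..., H(s_{t+n-1}).

    entropies: sequence of n policy entropies, ordered from s_t onwards
    gamma: discount factor (assumed weighting: gamma**k for step t + k)
    alpha: regularization strength applied to the whole sum
    """
    return alpha * sum(gamma ** k * h for k, h in enumerate(entropies))


print(nstep_entropy_bonus([1.0, 2.0, 3.0], gamma=0.5))  # 2.75
```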

    opened by frederikschubert 3
  • Implementation of SAC

    Since SAC is really similar to TD3, we are able to re-use most of its components. The differences are:

    • The actions to update the q-functions and policy are sampled using the current policy (instead of taking the mode).
    • There is no target policy.
    • The log variance of the policy depends on the state.
    • The policy is entropy regularized.

    The current implementation does not support multi-step td-learning.

    opened by frederikschubert 3
  • AttributeError: module 'jax.api' has no attribute '_jit_is_disabled'

    Hi, I'm unsure whether this is due to coax or jax, but I get this error when running the pendulum PPO example. DQN runs fine, however.

    A similar error I found online recommended changing the version of jaxlib, so I switched to the jaxlib version set out in the coax getting started guide, but that seemed to have no effect. Versions: jax 0.2.13, jaxlib 0.1.65+cuda111, coax 0.1.6.

    question 
    opened by mmcaulif 3
  • Recurrent Experience Replay

    Is your feature request related to a problem? Please describe.

    It seems that the implemented replay buffers only operate over transitions, with no ability to operate over entire sequences. This prevents the use of recurrent policies for tackling POMDPs.

    Describe the solution you'd like

    A SequenceReplayBuffer that returns contiguous episodes instead of shuffled transitions.
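
To make the request concrete, here is a minimal standard-library sketch of what such a buffer could look like. It is hypothetical: the class does not exist in coax, and the method names, the `capacity` semantics (number of stored episodes), and sampling without replacement are all assumptions.

```python
import random
from collections import deque


class SequenceReplayBuffer:
    """Stores whole episodes and samples them as contiguous sequences."""

    def __init__(self, capacity, seed=None):
        self._episodes = deque(maxlen=capacity)  # oldest episodes are evicted
        self._current = []  # transitions of the episode in progress
        self._rng = random.Random(seed)

    def add(self, transition, done):
        """Append a transition; close the episode when done is True."""
        self._current.append(transition)
        if done:
            self._episodes.append(self._current)
            self._current = []

    def __len__(self):
        return len(self._episodes)

    def sample(self, batch_size):
        """Return batch_size complete episodes, each in original order."""
        if batch_size > len(self._episodes):
            raise ValueError("not enough complete episodes in the buffer")
        return self._rng.sample(list(self._episodes), batch_size)


buffer = SequenceReplayBuffer(capacity=100, seed=0)
for episode in range(3):
    for t in range(4):
        buffer.add((episode, t), done=(t == 3))
print(len(buffer))  # 3
```

A recurrent policy could then be unrolled over each sampled episode from its first step, so that the hidden state is built from the same observation history the agent saw.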

    Describe alternatives you've considered

    Additional context

    enhancement 
    opened by smorad 3
  • MiniMax Algorithm?

    How would you implement a minimax q-learner with coax?

    Hi there! I love the package and how accessible it is to relative newbies. The tutorials are pretty great and the accompanying videos are very helpful!

    I was wondering what the best way to implement a minimax algorithm would be, would you recommend using two policies pi1 and pi2? Or is there something better suited for this?
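
As a starting point for discussion, the core of a minimax backup can be written independently of coax. The sketch below assumes a two-player zero-sum game with simultaneous moves and a tabular Q-value indexed as q_state[a1][a2] from player 1's point of view, and it only considers pure strategies (the minimax-Q algorithm in the literature solves for mixed strategies with a linear program). The function name is hypothetical.

```python
def minimax_value(q_state):
    """Return (best_action, value) for the maximizing player.

    q_state[a1][a2] is player 1's Q-value when player 1 plays a1 and
    player 2 plays a2. Player 2 is assumed to pick the worst case for
    player 1, and player 1 picks the action with the best worst case.
    """
    worst_case = [min(row) for row in q_state]
    best_action = max(range(len(worst_case)), key=worst_case.__getitem__)
    return best_action, worst_case[best_action]


q_state = [
    [1.0, -1.0],  # action 0: great if the opponent plays 0, bad otherwise
    [0.0, 0.5],   # action 1: never worse than 0.0
]
print(minimax_value(q_state))  # (1, 0.0)
```

In a Q-learning update, this value would take the place of the usual max over next-state actions when forming the bootstrapped target.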

    I'd like to re-implement something like this old blogpost of mine in coax to get a better feel of the library.

    Any help would be greatly appreciated :)

    question 
    opened by flaport 1
  • Convert Numpy Docstrings to Google Style

    This issue tracks the progress of converting the numpy style docstrings to the more concise Google style.

    • [ ] _core
    • [ ] experience_replay
    • [ ] model_updaters
    • [ ] policy_objectives
    • [ ] proba_dists
    • [ ] reward_tracing
    • [ ] td_learning
    • [ ] utils
    • [ ] value_transforms
    • [ ] wrappers

    This depends on the type annotations https://github.com/coax-dev/coax/issues/13 for easier automatic conversions.

    enhancement 
    opened by frederikschubert 0
  • Add Type Annotations

    This issue tracks the progress of adding type annotations to coax.

    • [ ] _core
    • [ ] experience_replay
    • [ ] model_updaters
    • [ ] policy_objectives
    • [ ] proba_dists
    • [ ] reward_tracing
    • [ ] td_learning
    • [ ] utils
    • [ ] value_transforms
    • [ ] wrappers

    The types are collected with pyannotate by adding the following snippet to the coax._base.TestCase class:

    import sys
    from pyannotate_runtime import collect_types

    ...
        @classmethod
        def setUpClass(cls) -> None:
            collect_types.init_types_collection()
            collect_types.start()
    
        @classmethod
        def tearDownClass(cls) -> None:
            collect_types.stop()
            type_replacements = {
                "jaxlib.xla_extension.DeviceArray": "jax.numpy.ndarray",
                "haiku._src.data_structures.FlatMapping": "typing.Mapping",
                "coax._core.policy_test": "gym.Env"
            }
            types_str = collect_types.dumps_stats()
            for inferred_type, replacement in type_replacements.items():
                types_str = types_str.replace(inferred_type, replacement)
            with open(sys.modules[cls.__module__].__file__.replace(".py", "_types.json"), "w") as f:
                f.write(types_str)
    ...
    

    The collected types are then applied automatically:

    # in bash, run `shopt -s globstar` first so that ** matches recursively
    for t in coax/**/*_test_types.json
    do
        pyannotate --type-info "$t" -3 coax/* -w
    done
    
    enhancement 
    opened by frederikschubert 0
  • PPOClip grad update seems to cause inf update

    PPOClip grad update seems to cause inf update

    Describe the bug

    Hey Kris, love your framework! I'm working with a custom environment, and your discrete-action unit test works perfectly locally. Don't spend much time investigating this yet; I'm just creating this issue in case something jumps out at you as the problem. I plan on continuing to debug this issue.

    During the first PPOClip update with the custom gym environment, some model weights get changed to +/-inf even though all gradients are finite.

    Expected behavior

    ...
    adv = np.random.rand(32)
    grads, function_state, metrics = ppo_clip.grads_and_metrics(transition_batch, Adv=adv)
    print("grads", grads)
    print(ppo_clip._pi.params)
    metrics_pi = ppo_clip.update(transition_batch, Adv=adv) # This is the problem
    print(ppo_clip._pi.params)
    

    Results in:

    grads FlatMapping({
      'linear': FlatMapping({
                  'b': DeviceArray([ 0.0477 , -0.02505, -0.05048,  0.02798], dtype=float16),
                  'w': DeviceArray([[ 0.01338 , -0.01921 , -0.01038 ,  0.01622 ],
                                    [ 0.02406 , -0.01683 , -0.02039 ,  0.01316 ],
                                    [ 0.0332  , -0.0227  , -0.03108 ,  0.02061 ],
                                    ...,
                                    [ 0.02452 , -0.00956 , -0.01997 ,  0.005024],
                                    [ 0.010025,  0.001724, -0.03467 ,  0.02295 ],
                                    [ 0.01886 , -0.01413 , -0.01494 ,  0.01022 ]], dtype=float16),
                }),
    
    FlatMapping({
      'linear': FlatMapping({
                  'w': DeviceArray([[-1.0124e-02,  3.4389e-03,  2.9316e-03,  6.5498e-03],
                                    [ 3.3302e-03, -1.7233e-03, -3.0422e-03, -1.8060e-04],
                                    [-2.8908e-05, -3.3131e-03, -6.1073e-03,  6.5804e-03],
                                    ...,
                                    [-2.5597e-03,  7.3471e-03, -3.6221e-03, -5.6801e-03],
                                    [-7.3471e-03, -3.7746e-03,  5.8746e-03,  6.1531e-03],
                                    [-1.1940e-03,  6.9733e-03, -5.0507e-03,  3.4218e-03]],            dtype=float16),
                  'b': DeviceArray([0., 0., 0., 0.], dtype=float16),
                }),
    })
    
    FlatMapping({
      'linear': FlatMapping({
                  'b': DeviceArray([-0.001002,  0.000978,  0.001001, -0.001007], dtype=float16),
                  'w': DeviceArray([[-0.01111  ,  0.004448 ,  0.00386  ,  0.00551  ],
                                    [ 0.002354 , -0.0007563, -0.002048 , -0.001162 ],
                                    [-0.001021 , -0.002335 , -0.005104 ,  0.005558 ],
                                    ...,
                                    [-0.003561 ,  0.008224 , -0.002628 ,       -inf],
                                    [-0.00828  ,       -inf,  0.006874 ,  0.00515  ],
                                    [-0.002203 ,  0.00804  , -0.004086 ,  0.002493 ]],            dtype=float16),
                }),
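One hypothesis that is consistent with the printed values, but is not confirmed anywhere in this issue, is float16 underflow inside Adam. The sketch below emulates a first Adam step in pure Python, rounding every intermediate result to half precision. It assumes that the optimizer state has the same float16 dtype as the parameters and that Adam uses b1=0.9, b2=0.999 and eps=1e-8. The helper names `f16` and `adam_first_step_f16` are made up for this illustration.

```python
import math
import struct


def f16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def adam_first_step_f16(grad, lr=3e-4, b1=0.9, b2=0.999, eps=1e-8):
    """Parameter change of the first Adam step with float16 intermediates."""
    g = f16(grad)
    m = f16((1 - b1) * g)                     # first moment
    v = f16(f16(1 - b2) * f16(g * g))         # second moment, may underflow to 0
    m_hat = f16(m / f16(1 - b1))              # bias correction for step t=1
    v_hat = f16(v / f16(1 - b2))
    denom = f16(math.sqrt(v_hat) + f16(eps))  # eps=1e-8 itself rounds to 0 in float16
    if denom == 0.0:
        # Python raises on division by zero; array libraries return +/-inf instead
        return math.copysign(math.inf, -m_hat)
    return f16(-lr * m_hat / denom)


for grad in (0.02798, 0.010025, 0.005024, 0.001724):
    print(grad, adam_first_step_f16(grad))
```

Under these assumptions the two smallest positive gradients, 0.005024 and 0.001724, produce an update of -inf, while the larger ones stay finite. Those are the two entries that turn into -inf in the output above. If this is indeed the cause, casting observations and parameters to float32 would be one thing to try.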
    

    Here is the full repro script, taken from the Pong PPO example and slightly modified. It won't run as-is because it depends on the custom environment. This is a dummy example, not the actual policy and value networks that would be used:

    import os
    from luxai2021.env.lux_env import LuxEnvironment, LuxEnvironmentTeam
    from luxai2021.game.game import Game
    from luxai2021.game.actions import *
    from luxai2021.game.constants import LuxMatchConfigs_Default
    
    from luxai2021.env.agent import Agent, AgentWithTeamModel
    import numpy as np
    
    from agent import TeamAgent
    
    # set some env vars
    os.environ.setdefault('JAX_PLATFORM_NAME', 'cpu')     # tell JAX to use CPU
    os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.1'  # don't use all gpu mem
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'              # tell XLA to be quiet
    
    import gym
    import jax
    import coax
    import haiku as hk
    import jax.numpy as jnp
    from optax import adam
    
    
    # the name of this script
    name = 'ppo'
    
    configs = LuxMatchConfigs_Default
    
    player = TeamAgent(mode="train")
    opponent = Agent()
    
    env = LuxEnvironment(configs=configs,
                                    learning_agent=player,
                                    opponent_agent=opponent)
    env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f"./data/tensorboard/{name}")
    
    def func_pi(S, is_training):
        n_actions = 4
        out = {'logits': hk.Linear(n_actions)(hk.Flatten()(S)) }
        return out
    
    def func_v(S, is_training):
        h = jnp.ravel(hk.Linear(1)(hk.Flatten()(S)))
        return h
    
    '''
    def func_pi(S, is_training):
        #print(env.action_space.shape)
        n_filters = 5
        n_actions = 4
        n_layers = 3
    
        h = hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(S)
        for layer in range(n_layers):
            h = jax.nn.relu(h + hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(h))
        
        print('h', type(h), h.shape)
        h_head = (h * S[:,:1]).reshape(h.shape[0], h.shape[1], -1).sum(-1) # torch.Size([1, N_LAYERS])
        h_head_actions = hk.Linear(n_actions)(h_head)
        print('h_head_actions', type(h_head_actions), h_head_actions.shape)
        #print(h_head_actions)
    
        out = {'logits': h_head_actions}
        
        return out
    
    def func_v(S, is_training):
        n_filters = 5
        n_layers = 3
    
        h = hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(S)
        for layer in range(n_layers):
            h = jax.nn.relu(hk.Conv2D(n_filters, kernel_shape=3, stride=2, data_format='NCHW')(h))
    
        h = hk.Flatten()(h)
        h = jax.nn.relu(hk.Linear(64)(h))
        h = jnp.ravel(hk.Linear(1, w_init=jnp.zeros)(h))
        
        return h
    '''
    
    
    # function approximators
    pi = coax.Policy(func_pi, env)
    v = coax.V(func_v, env)
    
    # target networks
    pi_behavior = pi.copy()
    v_targ = v.copy()
    
    # policy regularizer (avoid premature exploitation)
    entropy = coax.regularizers.EntropyRegularizer(pi, beta=0.001)
    
    # updaters
    simpletd = coax.td_learning.SimpleTD(v, v_targ, optimizer=adam(3e-4))
    ppo_clip = coax.policy_objectives.PPOClip(pi, regularizer=entropy, optimizer=adam(3e-4))
    
    # reward tracer and replay buffer
    tracer = coax.reward_tracing.NStep(n=5, gamma=0.99)
    buffer = coax.experience_replay.SimpleReplayBuffer(capacity=256)
    
    # run episodes
    max_episode_steps = 400
    while env.T < 3000000:
        s = env.reset()
    
        for t in range(max_episode_steps):
            print(t)
            a, logp = pi_behavior(s, return_logp=True)
            s_next, r, done, info = env.step(a)
    
            # trace rewards and add transition to replay buffer
            tracer.add(s, a, r, done, logp)
            while tracer:
                buffer.add(tracer.pop())
    
            # learn
            if len(buffer) >= buffer.capacity:
                num_batches = int(4 * buffer.capacity / 32)  # 4 epochs per round
                for i in range(num_batches):
                    transition_batch = buffer.sample(32)
                    grads, function_state, metrics, td_error = simpletd.grads_and_metrics(transition_batch)
                    metrics_v, td_error = simpletd.update(transition_batch, return_td_error=True)
    
                    
                    adv = np.random.rand(32)
                    grads, function_state, metrics = ppo_clip.grads_and_metrics(transition_batch, Adv=adv)
                    print("grads", grads)
                    print(ppo_clip._pi.params)
                    metrics_pi = ppo_clip.update(transition_batch, Adv=adv) # This is the problem
                    print(ppo_clip._pi.params)
                    exit()
                    env.record_metrics(metrics_pi)
                    env.record_metrics(metrics_v)
                    
    
                buffer.clear()
    
                # sync target networks
                pi_behavior.soft_update(pi, tau=0.1)
                v_targ.soft_update(v, tau=0.1)
    
            if done:
                break
    
            s = s_next
    
        # generate an animated GIF to see what's going on
        if env.period(name='generate_gif', T_period=10000) and env.T > 50000:
            T = env.T - env.T % 10000  # round to 10000s
            coax.utils.generate_gif(
                env=env, policy=pi, resize_to=(320, 420),
                filepath=f"./data/gifs/{name}/T{T:08d}.gif")
    
    
    opened by glmcdona 3
Releases(v0.1.12)