tristandeleu/pytorch-maml-rl

k-shot testing script

Closed this issue · 11 comments

Hi! Thanks for your awesome work!

I am wondering if you have implemented a MAML test script like in the original paper, where we can test pretrained MAML agents and plot the k-shot rewards?

I am planning on using your repo for a project, and this functionality would be very useful. Thanks!

[Screenshot: k-shot adaptation reward curves from the MAML paper]
I'd like to produce something similar to this figure from the Model-Agnostic Meta-Learning paper. Can we compare MAML with a non-MAML baseline using this code base?
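Concretely, the plot I have in mind is average return versus number of gradient steps, for a MAML-pretrained policy and a non-MAML baseline. A rough sketch of just the plotting side (the function name and the return lists are placeholders I'd fill in from my own evaluation runs):

import matplotlib.pyplot as plt

def plot_k_shot(steps, returns_maml, returns_baseline):
    # steps: number of gradient steps, e.g. [0, 1, 2, 3]
    # returns_*: average return measured after each number of gradient steps
    plt.plot(steps, returns_maml, marker='o', label='MAML (pretrained)')
    plt.plot(steps, returns_baseline, marker='o', label='pretrained (no MAML)')
    plt.xlabel('number of gradient steps')
    plt.ylabel('average return')
    plt.legend()
    plt.show()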

Thanks for the kind words!
I did write a test script to test pretrained agents, but it is really ad hoc. I'll try to clean it up and push it as soon as possible.

Thanks! Appreciate it :)

Here is a small example that I used to verify one-shot adaptation on the point mass environment. It is a bit rough since I wrote it just to verify something. Hope it helps! @hzyjerry

import numpy as np
import torch

from maml_rl.envs.point_env import PointEnv
from maml_rl.policies.normal_mlp import NormalMLPPolicy
from maml_rl.baseline import LinearFeatureBaseline
from maml_rl.sampler import BatchSampler
from maml_rl.metalearner import MetaLearner

ITR = 120

#torch.manual_seed(7)

META_POLICY_PATH = "/somepath/policy-{}.pt".format(ITR)
BASELINE_PATH = "/somepath/baseline-{}.pt".format(ITR)

TEST_TASKS = [
    (5., 5.)
]


def load_meta_learner_params(policy_path, baseline_path, env):
    policy_params = torch.load(policy_path)
    baseline_params = torch.load(baseline_path)

    policy = NormalMLPPolicy(
        int(np.prod(env.observation_space.shape)),
        int(np.prod(env.action_space.shape)), 
        hidden_sizes=(100, 100)) # We should actually get this from config
    policy.load_state_dict(policy_params)
    
    baseline = LinearFeatureBaseline(int(np.prod(env.observation_space.shape)))
    baseline.load_state_dict(baseline_params)

    return policy, baseline


def evaluate(env, task, policy, max_path_length=100):
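    # Roll out the (adapted) policy for one episode in the given task, rendering each step,
    # and print the cumulative return.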
    cum_reward = 0
    t = 0
    env.reset_task(task)
    obs = env.reset()
    for _ in range(max_path_length):
        env.render()
        obs_tensor = torch.from_numpy(obs).to(device='cpu').type(torch.FloatTensor)
        action_tensor = policy(obs_tensor, params=None).sample()
        action = action_tensor.cpu().numpy()
        obs, rew, done, _ = env.step(action)
        cum_reward += rew
        t += 1
        if done:
            break

    print("========EVAL RESULTS=======")
    print("Return: {}, Timesteps:{}".format(cum_reward, t))
    print("===========================")


def main():
    env = PointEnv()
    policy, baseline = load_meta_learner_params(META_POLICY_PATH, BASELINE_PATH, env)
    sampler = BatchSampler(env_name="Point", batch_size=20, real_batch_size=400, num_workers=2)
    learner = MetaLearner(sampler, policy, baseline)

    for task in TEST_TASKS:
        env.reset_task(task)
        
        # Sample a batch of transitions
        sampler.reset_task(task)
        episodes = sampler.sample(policy)
        new_params = learner.adapt(episodes)
        policy.load_state_dict(new_params)
        evaluate(env, task, policy)


if __name__ == '__main__':
    main()
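If you want the k-shot curve from the paper (return after 1, 2, ..., K gradient steps) rather than a single adaptation step, the loop in main() could be extended along the lines of the rough, untested sketch below. It reuses the same objects (env, sampler, learner, policy, TEST_TASKS) and assumes evaluate() is changed to return cum_reward:

K = 3  # number of inner-loop gradient steps (my addition, not in the script above)
for task in TEST_TASKS:
    env.reset_task(task)
    sampler.reset_task(task)
    k_shot_returns = []
    for k in range(K):
        episodes = sampler.sample(policy)     # rollouts with the current (partially adapted) policy
        new_params = learner.adapt(episodes)  # one more inner-loop gradient step
        policy.load_state_dict(new_params)
        k_shot_returns.append(evaluate(env, task, policy))  # assumes evaluate() returns cum_reward
    print("Returns after each gradient step:", k_shot_returns)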

@zhanpenghe thanks so much for sharing!

I tried the same script as @zhanpenghe for HalfCheetahDir-v1 but got horrible results.

import numpy as np
import torch

from maml_rl.envs.mujoco.half_cheetah import HalfCheetahDirEnv
from maml_rl.policies.normal_mlp import NormalMLPPolicy
from maml_rl.baseline import LinearFeatureBaseline
from maml_rl.sampler import BatchSampler
from maml_rl.metalearner import MetaLearner
from gym.utils import seeding

from tensorboardX import SummaryWriter

ITR = 120

torch.manual_seed(7)
seed = 7


def seed_def(seed=None):
    np_random, seed = seeding.np_random(seed)
    return np_random


META_POLICY_PATH = "somepath/pytorch-maml-rl/saves/maml/policy-120.pt"


def sample_tasks(num_tasks, seed=None):
    np_random = seed_def(seed)
    directions = 2 * np_random.binomial(1, p=0.5, size=(num_tasks,)) - 1
    tasks = [{'direction': direction} for direction in directions]
    return tasks


def load_meta_learner_params(policy_path, env, num_layers=2):
    policy_params = torch.load(policy_path)

    policy = NormalMLPPolicy(
        int(np.prod(env.observation_space.shape)),
        int(np.prod(env.action_space.shape)),
        hidden_sizes=(100,) * num_layers)  # We should actually get this from config
    policy.load_state_dict(policy_params)

    baseline = LinearFeatureBaseline(int(np.prod(env.observation_space.shape)))

    return policy, baseline, policy_params


def evaluate(env, task, policy, policy_params, max_path_length=100):
    cum_reward = 0
    t = 0
    #env.reset_task(task)
    obs = env.reset()
    for _ in range(max_path_length):
        obs_tensor = torch.from_numpy(obs).to(device='cpu').type(torch.FloatTensor)
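        # params=policy_params uses the parameters loaded from disk; with params=None the
        # policy's current parameters (as set by load_state_dict) would be used instead.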
        action_tensor = policy(obs_tensor, params=policy_params).sample()
        action = action_tensor.cpu().numpy()
        obs, rew, done, _ = env.step(action)
        cum_reward += rew
        t += 1
        if done:
            break

    print("========EVAL RESULTS=======")
    print("Return: {}, Timesteps:{}".format(cum_reward, t))
    print("===========================")

    return cum_reward


def main():
    env = HalfCheetahDirEnv()
    policy, baseline, params = load_meta_learner_params(META_POLICY_PATH, env)
    sampler = BatchSampler(env_name='HalfCheetahDir-v1',
                           batch_size=20)
    learner = MetaLearner(sampler, policy, baseline)
    writer = SummaryWriter()

    num_updates = 3

    cum_reward = 0

    TEST_TASKS = sample_tasks(2, seed=None)

    for i, task in enumerate(TEST_TASKS):
        print(task)
        env.reset_task(task)
        # Sample a batch of transitions
        sampler.reset_task(task)
        episodes = sampler.sample(policy)
        for u in range(num_updates):
            new_params = learner.adapt(episodes)
            policy.load_state_dict(new_params)
            cum_reward = evaluate(env, task, policy, params)
            writer.add_scalar('data/cumm_reward', cum_reward, i)


if __name__ == '__main__':
    main()

So I made some tweaks, but for some reason the network still performs horribly on the new task after 3-4 meta updates. Any ideas?

@navneet-nmk, could you share what results you got for HalfCheetahDir-v1? Thanks!

Same as @navneet-nmk, I tried the script and didn't see good results. Does anyone know what's going on?
Thanks!

Has anyone implemented the test code?

Hi everyone, sorry for the lack of updates on this issue. I have now updated the repo with a new version of the code, which includes a script to test the policy. The script should work with policies trained with the previous version of the code (although the config might be different), but this is not guaranteed.