-
Notifications
You must be signed in to change notification settings - Fork 1
/
vis_simple.py
69 lines (48 loc) · 1.81 KB
/
vis_simple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import yaml
import pickle
import argparse
import numpy as np
import torch as T
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from datetime import datetime
from collections import namedtuple
from Harlow_Simple.harlow import HarlowSimple
from models.a3c_lstm_simple import A3C_LSTM
def run_episode(agent, env, device="cpu"):
    """Roll out a single greedy episode of ``agent`` in ``env``.

    The agent is switched to eval mode and, at every step, receives the
    current state together with the previous action (one-hot encoded) and
    the previous reward, plus its recurrent state. The highest-scoring
    action is taken until the environment reports ``done``.

    Args:
        agent: Model exposing ``eval()`` and ``get_init_states(device)``;
            called as ``agent(state, (prev_action, prev_reward), (ht, ct))``
            and expected to return ``(logit, _, (ht, ct))``.
        env: Environment exposing ``reset()``, ``step(action)`` and the
            integer attribute ``n_actions``.
        device: Torch device the input tensors are moved to.

    Returns:
        None. The function is run for its effect on ``env``.
    """
    agent.eval()

    done = False
    state = env.reset()

    # "No previous action" is an all-zero vector with the same width as the
    # one-hot vectors built from ``env.n_actions`` inside the loop.
    p_action, p_reward = [0] * env.n_actions, 0
    ht, ct = agent.get_init_states(device)

    # Pure inference: without no_grad the autograd graph would be chained
    # through (ht, ct) and grow for the whole length of the episode.
    with T.no_grad():
        while not done:
            logit, _, (ht, ct) = agent(
                T.tensor([state]).float().to(device),
                (
                    T.tensor([p_action]).float().to(device),
                    T.tensor([[p_reward]]).float().to(device),
                ),
                (ht, ct),
            )

            # Greedy policy: pick the most probable action.
            action = T.argmax(F.softmax(logit, dim=-1), -1)

            state, reward, done, _ = env.step(action)

            p_action = np.eye(env.n_actions)[action]
            p_reward = reward

    # NOTE(review): presumably this final reset lets a visualizing env flush
    # the finished episode - confirm against the env implementation.
    env.reset()
if __name__ == "__main__":
    # Script entry point: read the YAML config, rebuild the agent from a
    # saved checkpoint and run one visualized episode.
    parser = argparse.ArgumentParser(description='Paramaters')
    parser.add_argument('-c', '--config', type=str, default="Harlow_Simple/config.yaml", help='path of config file')
    args = parser.parse_args()

    with open(args.config, 'r', encoding="utf-8") as fin:
        config = yaml.load(fin, Loader=yaml.FullLoader)

    load_path = config["load-path"]
    # The "{epi:04d}" placeholder is left unformatted here and handed to the
    # environment as part of the path template.
    save_path = os.path.join(
        config["save-path"],
        config["run-title"],
        config["run-title"] + "_{epi:04d}.gif",
    )

    agent = A3C_LSTM(
        config["task"]["input-dim"],
        config["agent"]["mem-units"],
        config["task"]["num-actions"],
    )

    # The agent is built and run on CPU, so map the checkpoint tensors to CPU
    # as well; otherwise a checkpoint saved from a CUDA device cannot be
    # loaded on a machine without CUDA.
    checkpoint = T.load(load_path, map_location="cpu")
    agent.load_state_dict(checkpoint["state_dict"])

    env = HarlowSimple(visualize=True, save_interval=1, save_path=save_path)

    run_episode(agent, env)