This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 5388318

Lukasz Kaiser authored and Ryan Sepassi committed
Adding a first Gym problem for generative RL models.
PiperOrigin-RevId: 179694851
1 parent a66cfaf commit 5388318

3 files changed: +140 −0 lines changed

setup.py

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@
     install_requires=[
         'bz2file',
         'future',
+        'gym',
         'numpy',
         'requests',
         'sympy',

tensor2tensor/data_generators/all_problems.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
 from tensor2tensor.data_generators import cipher
 from tensor2tensor.data_generators import cnn_dailymail
 from tensor2tensor.data_generators import desc2code
+from tensor2tensor.data_generators import gym
 from tensor2tensor.data_generators import ice_parsing
 from tensor2tensor.data_generators import image
 from tensor2tensor.data_generators import imdb
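This one-line import matters for its side effect: gym.py (shown below) decorates its problem class with @registry.register_problem, and importing the module from all_problems.py is what makes the problem visible to the registry. A minimal sketch of checking that, assuming t2t's registry API and its usual CamelCase-to-snake_case problem naming:

# Sketch only, not part of the commit. Importing all_problems triggers the
# registration side effects; the problem name below is assumed to be the
# registry's snake_case conversion of GymPongRandom5k.
from tensor2tensor.data_generators import all_problems  # noqa: F401
from tensor2tensor.utils import registry

assert "gym_pong_random5k" in registry.list_problems()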

tensor2tensor/data_generators/gym.py

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
+# coding=utf-8
+# Copyright 2017 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Data generators for Gym environments."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+# Dependency imports
+
+import gym
+
+from tensor2tensor.data_generators import generator_utils
+from tensor2tensor.data_generators import problem
+from tensor2tensor.utils import registry
+
+import tensorflow as tf
+
+
+
+class GymDiscreteProblem(problem.Problem):
+  """Gym environment with discrete actions and rewards."""
+
+  def __init__(self, *args, **kwargs):
+    super(GymDiscreteProblem, self).__init__(*args, **kwargs)
+    self._env = None
+
+  @property
+  def env_name(self):
+    # This is the name of the Gym environment for this problem.
+    raise NotImplementedError()
+
+  @property
+  def env(self):
+    if self._env is None:
+      self._env = gym.make(self.env_name)
+    return self._env
+
+  @property
+  def num_actions(self):
+    raise NotImplementedError()
+
+  @property
+  def num_rewards(self):
+    raise NotImplementedError()
+
+  @property
+  def num_steps(self):
+    raise NotImplementedError()
+
+  @property
+  def num_shards(self):
+    return 10
+
+  @property
+  def num_dev_shards(self):
+    return 1
+
+  def get_action(self, observation=None):
+    return self.env.action_space.sample()
+
+  def hparams(self, defaults, unused_model_hparams):
+    p = defaults
+    p.input_modality = {"inputs": ("image:identity", 256),
+                        "inputs_prev": ("image:identity", 256),
+                        "reward": ("symbol:identity", self.num_rewards),
+                        "action": ("symbol:identity", self.num_actions)}
+    p.target_modality = ("image:identity", 256)
+    p.input_space_id = problem.SpaceID.IMAGE
+    p.target_space_id = problem.SpaceID.IMAGE
+
+  def generator(self, data_dir, tmp_dir):
+    self.env.reset()
+    action = self.get_action()
+    prev_observation, observation = None, None
+    for _ in range(self.num_steps):
+      prev_prev_observation = prev_observation
+      prev_observation = observation
+      observation, reward, done, _ = self.env.step(action)
+      action = self.get_action(observation)
+      if done:
+        self.env.reset()
+      def flatten(nparray):
+        flat1 = [x for sublist in nparray.tolist() for x in sublist]
+        return [x for sublist in flat1 for x in sublist]
+      if prev_prev_observation is not None:
+        yield {"inputs_prev": flatten(prev_prev_observation),
+               "inputs": flatten(prev_observation),
+               "action": [action],
+               "done": [done],
+               "reward": [reward],
+               "targets": flatten(observation)}
+
+  def generate_data(self, data_dir, tmp_dir, task_id=-1):
+    train_paths = self.training_filepaths(
+        data_dir, self.num_shards, shuffled=False)
+    dev_paths = self.dev_filepaths(
+        data_dir, self.num_dev_shards, shuffled=False)
+    all_paths = train_paths + dev_paths
+    generator_utils.generate_files(
+        self.generator(data_dir, tmp_dir), all_paths)
+    generator_utils.shuffle_dataset(all_paths)
+
+
+@registry.register_problem
+class GymPongRandom5k(GymDiscreteProblem):
+  """Pong game, random actions."""
+
+  @property
+  def env_name(self):
+    return "Pong-v0"
+
+  @property
+  def num_actions(self):
+    return 4
+
+  @property
+  def num_rewards(self):
+    return 2
+
+  @property
+  def num_steps(self):
+    return 5000
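A hedged usage sketch (not part of the commit) of how the new problem might be exercised end to end. The data directories are placeholders, "Pong-v0" requires an Atari-enabled gym install, and the registered problem name is assumed to follow the registry's snake_case naming convention.

# Illustrative only. Rolls out 5000 random-action Pong steps and writes
# 10 unshuffled training shards plus 1 dev shard, then shuffles them,
# following GymDiscreteProblem.generate_data above.
from tensor2tensor.data_generators import gym as t2t_gym  # noqa: F401 (registers the problem)
from tensor2tensor.utils import registry

pong = registry.problem("gym_pong_random5k")  # assumed registered name
pong.generate_data("/tmp/t2t_data", "/tmp/t2t_tmp")

Each yielded example pairs two consecutive frames ("inputs_prev", "inputs") with the sampled action, the observed reward and done flag, and the next frame as "targets", which is the training signal a generative next-frame model would consume.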
