AI4Finance-Foundation/ElegantRL

DQN

Opened this issue · 0 comments

class mazeenv(gym.Env):
    """8-connected grid-world maze for discrete-action agents such as DQN.

    The board is ``GRID_SIZE x GRID_SIZE`` cells with integer coordinates in
    ``[0, GRID_SIZE)``. The observation is the flat vector
    ``[pos_r, pos_c, goal_r, goal_c, obst0_r, obst0_c, ...]`` returned as
    ``float32`` so it matches ``observation_space``. Its length is
    ``4 + 2 * len(OBSTACLES)``, which is 14 with the default obstacles.

    Rewards:
        * -1 for an ordinary move.
        * -100 for a move that would leave the board or enter an obstacle;
          the agent stays where it was.
        * +100 and ``done=True`` when the agent lands on the goal.

    NOTE(review): the goal/obstacle part of the observation never changes
    between episodes, and coordinates are not normalised. If DQN still fails
    to reach the goal, consider scaling observations to [0, 1] and passing
    ``max_steps`` so episodes cannot run forever.
    """

    GRID_SIZE = 23
    # Obstacle cells: used for collision checks and embedded in every observation.
    OBSTACLES = ((2, 2), (3, 3), (5, 5), (7, 9), (11, 2))
    START = (1, 1)
    GOAL = (22, 22)
    # (row delta, col delta) for each discrete action. The original table
    # listed (1, -1) twice and omitted (-1, -1), so one diagonal was unreachable.
    MOVES = ((-1, -1), (1, -1), (1, 1), (-1, 1), (1, 0), (-1, 0), (0, 1), (0, -1))

    def __init__(self, max_steps=None):
        """Build the action/observation spaces, seed the RNG and reset.

        Args:
            max_steps: optional episode length limit. Once this many steps
                have been taken, ``step`` returns ``done=True``. ``None``
                (the default) means no limit, matching the original behaviour.
        """
        self.max_steps = max_steps
        self._steps = 0
        self.action_space = spaces.Discrete(len(self.MOVES))
        self.observation_space = spaces.Box(
            low=0,
            high=self.GRID_SIZE,
            shape=(4 + 2 * len(self.OBSTACLES),),
            dtype=np.float32,
        )
        self.state = None
        self.seed()
        self.reset()

    # The pasted code named the constructor ``init`` (presumably markdown
    # stripped the double underscores). Keep that name as an alias so any
    # explicit ``env.init()`` callers still work.
    init = __init__

    def seed(self, seed=None):
        """Seed the environment RNG and return ``[seed]`` as gym expects."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def _is_blocked(self, cell):
        """Return True if ``cell`` (a row/col pair) is off-board or an obstacle."""
        cell = np.asarray(cell).reshape(-1)
        if np.any(cell < 0) or np.any(cell >= self.GRID_SIZE):
            return True
        return any(np.array_equal(cell, o) for o in self.obs)

    def _random_cell(self):
        """Draw a random cell of shape (1, 2) from the seeded RNG."""
        # Depending on the gym version, ``seeding.np_random`` returns either a
        # numpy ``Generator`` (``integers``) or a ``RandomState`` (``randint``).
        draw = getattr(self.np_random, "integers", None) or self.np_random.randint
        return draw(0, self.GRID_SIZE, size=(1, 2))

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done, info)``."""
        assert self.action_space.contains(action), "%r (%s) invalid" % (
            action,
            type(action),
        )
        current = self.state[:2].copy()
        goal = self.state[2:4]
        position = current + np.asarray(self.MOVES[int(action)])
        self._steps += 1

        if self._is_blocked(position):
            # Bumped into the border or an obstacle: penalise and stay put.
            reward = -100
            position = current
            done = False
        elif np.array_equal(position, goal):
            reward = 100
            done = True
        else:
            reward = -1
            done = False

        info = {}
        if not done and self.max_steps is not None and self._steps >= self.max_steps:
            # Time-limit cut-off, flagged with the key gym's TimeLimit wrapper uses.
            done = True
            info["TimeLimit.truncated"] = True

        self.state[:2] = position
        self.ps = position
        return self.state.astype(np.float32), reward, done, info

    def reset(self):
        """Put the agent at ``START`` and return the initial observation.

        ``START`` and ``GOAL`` are re-drawn from the seeded RNG only if the
        configured cell is off-board or on an obstacle. The goal is also
        re-drawn if it coincides with the start.
        """
        # ``obs`` holds the obstacle cells, shape (n_obstacles, 2).
        self.obs = np.array(self.OBSTACLES)
        self.start = np.array([self.START])
        self.goal = np.array([self.GOAL])
        while self._is_blocked(self.start):
            self.start = self._random_cell()
        while self._is_blocked(self.goal) or np.array_equal(self.goal, self.start):
            self.goal = self._random_cell()
        self.state = np.concatenate((self.start, self.goal, self.obs), axis=0).reshape(-1)
        self.ps = self.start[0].copy()
        self._steps = 0
        return self.state.astype(np.float32)

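我设计了一个迷宫环境，但训练结束后进行测试时，智能体一直找不到终点。(I designed a maze environment, but when testing after training, the agent never reaches the goal.)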