-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathblackjack.py
More file actions
337 lines (280 loc) · 11.7 KB
/
blackjack.py
File metadata and controls
337 lines (280 loc) · 11.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
#######################################################################
# Copyright (C) #
# 2016-2018 Shangtong Zhang(zhangshangtong.cpp@gmail.com) #
# 2016 Kenta Shimada(hyperkentakun@gmail.com) #
# 2017 Nicky van Foreest(vanforeest@gmail.com) #
# Permission given to modify the code as long as you keep this #
# declaration at the top #
#######################################################################
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
# actions: hit or stand
ACTION_HIT = 0
ACTION_STAND = 1  # "strike" in the book
ACTIONS = [ACTION_HIT, ACTION_STAND]

# policy for player, indexed by the player's current sum (0..21):
# hit on every sum below 20, stand on 20 and 21.
# `int` replaces the `np.int` alias, which was removed in NumPy 1.24.
POLICY_PLAYER = np.full(22, ACTION_HIT, dtype=int)
POLICY_PLAYER[20:] = ACTION_STAND
def target_policy_player(usable_ace_player, player_sum, dealer_card):
    """Target policy: look up the fixed action for the player's current sum.

    Only ``player_sum`` is consulted (as an index into POLICY_PLAYER); the
    usable-ace flag and the dealer's card are accepted so that this function
    has the same signature as the other player policies.
    """
    chosen_action = POLICY_PLAYER[player_sum]
    return chosen_action
def behavior_policy_player(usable_ace_player, player_sum, dealer_card):
    """Behavior policy: stand or hit with equal probability, ignoring the state."""
    coin = np.random.binomial(1, 0.5)
    return ACTION_STAND if coin == 1 else ACTION_HIT
# policy for dealer, indexed by the dealer's current sum:
# entries below 17 are ACTION_HIT (0), entries 17..21 are ACTION_STAND
POLICY_DEALER = np.zeros(22)
POLICY_DEALER[12:17] = ACTION_HIT
POLICY_DEALER[17:22] = ACTION_STAND
def get_card():
    """Draw a card uniformly at random, with replacement.

    Ranks 1 (ace) through 13 are equally likely; ranks above 10 (face cards)
    are reported as 10.
    """
    rank = np.random.randint(1, 14)
    return min(rank, 10)
def card_value(card_id):
    """Return a card's point value, counting an ace (card id 1) as 11."""
    if card_id == 1:
        return 11
    return card_id
def _hit_hand(hand_sum, usable_ace):
    """Draw one card into a hand and re-resolve its aces.

    Arguments:
        hand_sum -- current total of the hand, counting a usable ace as 11
        usable_ace -- whether the hand currently counts an ace as 11
    Returns:
        (new_sum, new_usable_ace); new_sum may exceed 21, meaning the hand busts.
    """
    card = get_card()
    # The usable-ace flag alone cannot distinguish one ace counted as 11 from
    # two, so count the aces currently valued at 11.
    soft_aces = int(usable_ace)
    if card == 1:
        soft_aces += 1
    hand_sum += card_value(card)
    # Demote aces from 11 to 1 for as long as that avoids busting.
    while hand_sum > 21 and soft_aces:
        hand_sum -= 10
        soft_aces -= 1
    return hand_sum, soft_aces == 1


def play(policy_player, initial_state=None, initial_action=None):
    """Play one game of blackjack.

    Arguments:
        policy_player -- callable (usable_ace, player_sum, dealer_card) -> action
    Keyword Arguments:
        initial_state -- [usable_ace_player, player_sum, dealer_card1]; when None,
                         a random initial state is dealt
        initial_action -- action forced at the player's first decision; when None,
                          the first action also comes from policy_player
    Returns:
        (state, reward, player_trajectory) where state is the initial
        [usable_ace_player, player_sum, dealer_card1], reward is 1 (player wins),
        0 (draw) or -1 (player loses), and player_trajectory is a list of
        [(usable_ace_player, player_sum, dealer_card1), action] entries, one per
        player decision.
    """
    player_trajectory = []

    if initial_state is None:
        player_sum = 0
        # whether player uses an ace as 11
        usable_ace_player = False
        # generate a random initial state: always hit while the sum is below 12
        while player_sum < 12:
            card = get_card()
            player_sum += card_value(card)
            if player_sum > 21:
                # only 11 plus an ace counted as 11 can overshoot here
                assert player_sum == 22
                # count that last ace as 1 instead
                player_sum -= 10
            else:
                usable_ace_player |= (1 == card)
        # suppose the dealer shows the first card he gets
        dealer_card1 = get_card()
        dealer_card2 = get_card()
    else:
        # use specified initial state
        usable_ace_player, player_sum, dealer_card1 = initial_state
        dealer_card2 = get_card()

    # initial state of the game
    state = [usable_ace_player, player_sum, dealer_card1]

    # initialize dealer's sum
    dealer_sum = card_value(dealer_card1) + card_value(dealer_card2)
    usable_ace_dealer = 1 in (dealer_card1, dealer_card2)
    if dealer_sum > 21:
        # the dealer must hold two aces: use one ace as 1 rather than 11
        assert dealer_sum == 22
        dealer_sum -= 10
    assert dealer_sum <= 21
    assert player_sum <= 21

    # player's turn
    while True:
        if initial_action is not None:
            action = initial_action
            initial_action = None
        else:
            action = policy_player(usable_ace_player, player_sum, dealer_card1)
        # track player's trajectory for importance sampling
        player_trajectory.append([(usable_ace_player, player_sum, dealer_card1), action])
        if action == ACTION_STAND:
            break
        player_sum, usable_ace_player = _hit_hand(player_sum, usable_ace_player)
        if player_sum > 21:
            # player busts
            return state, -1, player_trajectory

    # dealer's turn: follow POLICY_DEALER until it says stand
    while POLICY_DEALER[dealer_sum] != ACTION_STAND:
        dealer_sum, usable_ace_dealer = _hit_hand(dealer_sum, usable_ace_dealer)
        if dealer_sum > 21:
            # dealer busts
            return state, 1, player_trajectory

    # compare the sum between player and dealer
    assert player_sum <= 21 and dealer_sum <= 21
    if player_sum > dealer_sum:
        return state, 1, player_trajectory
    if player_sum == dealer_sum:
        return state, 0, player_trajectory
    return state, -1, player_trajectory
def monte_carlo_on_policy(episodes):
    """Monte Carlo evaluation of the target player policy.

    Arguments:
        episodes -- number of games to simulate with target_policy_player
    Returns:
        (values_usable_ace, values_no_usable_ace) -- two 10x10 arrays indexed by
        [player_sum - 12, dealer_card - 1] holding the average return observed
        from each state; states never visited are 0.
    """
    states_usable_ace = np.zeros((10, 10))
    states_usable_ace_count = np.zeros((10, 10))
    states_no_usable_ace = np.zeros((10, 10))
    states_no_usable_ace_count = np.zeros((10, 10))

    def average(totals, counts):
        # Divide only where a state was visited. The previous approach started
        # the counts at 1, which divided by (n + 1) and shrank every estimate
        # toward 0.
        return np.divide(totals, counts, out=np.zeros_like(totals), where=counts > 0)

    for _ in range(episodes):
        _, reward, player_trajectory = play(target_policy_player)
        for (usable_ace, player_sum, dealer_card), _action in player_trajectory:
            row = player_sum - 12
            col = dealer_card - 1
            if usable_ace:
                states_usable_ace_count[row, col] += 1
                states_usable_ace[row, col] += reward
            else:
                states_no_usable_ace_count[row, col] += 1
                states_no_usable_ace[row, col] += reward
    return (average(states_usable_ace, states_usable_ace_count),
            average(states_no_usable_ace, states_no_usable_ace_count))
def monte_carlo_es(episodes):
    """Monte Carlo control with exploring starts.

    Each episode starts from a uniformly random state and action. After the
    first episode, the player acts greedily with respect to the current
    average returns, breaking ties at random.

    Arguments:
        episodes -- number of games to simulate
    Returns:
        array of shape (10, 10, 2, 2) indexed by
        [player_sum - 12, dealer_card - 1, usable_ace, action] holding the
        average first-visit return of each state-action pair; unvisited pairs are 0.
    """
    # (playerSum, dealerCard, usableAce, action)
    state_action_values = np.zeros((10, 10, 2, 2))
    state_action_pair_count = np.zeros((10, 10, 2, 2))

    def average_returns(totals, counts):
        # Divide only where a pair was visited. Starting counts at 1 would
        # divide by (n + 1) and bias both the estimates and the greedy choice.
        return np.divide(totals, counts, out=np.zeros_like(totals), where=counts > 0)

    # behavior policy is greedy
    def behavior_policy(usable_ace, player_sum, dealer_card):
        usable_ace = int(usable_ace)
        player_sum -= 12
        dealer_card -= 1
        # get argmax of the average returns(s, a), ties broken at random
        values_ = average_returns(state_action_values[player_sum, dealer_card, usable_ace, :],
                                  state_action_pair_count[player_sum, dealer_card, usable_ace, :])
        best = np.max(values_)
        return np.random.choice([action_ for action_, value_ in enumerate(values_) if value_ == best])

    # play for several episodes
    for episode in range(episodes):
        # for each episode, use a randomly initialized state and action
        initial_state = [bool(np.random.choice([0, 1])),
                         np.random.choice(range(12, 22)),
                         np.random.choice(range(1, 11))]
        initial_action = np.random.choice(ACTIONS)
        current_policy = behavior_policy if episode else target_policy_player
        _, reward, trajectory = play(current_policy, initial_state, initial_action)
        first_visit_check = set()
        for (usable_ace, player_sum, dealer_card), action in trajectory:
            usable_ace = int(usable_ace)
            player_sum -= 12
            dealer_card -= 1
            state_action = (usable_ace, player_sum, dealer_card, action)
            if state_action in first_visit_check:
                continue
            first_visit_check.add(state_action)
            # update values of state-action pairs
            state_action_values[player_sum, dealer_card, usable_ace, action] += reward
            state_action_pair_count[player_sum, dealer_card, usable_ace, action] += 1
    return average_returns(state_action_values, state_action_pair_count)
def montecarlo_evaluation(few_episodes=100, many_episodes=5000):
    """Plot Monte Carlo estimates of the target policy's state values.

    Runs monte_carlo_on_policy once with few_episodes and once with
    many_episodes. Saves a 2x2 grid of heatmaps (usable / no usable ace, short /
    long run) to ./images/black_jack_mc_evaluation.png, creating ./images if
    needed.

    Keyword Arguments:
        few_episodes -- number of episodes for the short run
        many_episodes -- number of episodes for the long run
    """
    import os

    states_usable_ace_1, states_no_usable_ace_1 = monte_carlo_on_policy(few_episodes)
    states_usable_ace_2, states_no_usable_ace_2 = monte_carlo_on_policy(many_episodes)
    states = [states_usable_ace_1,
              states_usable_ace_2,
              states_no_usable_ace_1,
              states_no_usable_ace_2]
    # Derive titles from the actual counts. They were previously hard-coded as
    # 10000 / 500000 while the runs used 100 / 5000.
    titles = ['Usable Ace, {} Episodes'.format(few_episodes),
              'Usable Ace, {} Episodes'.format(many_episodes),
              'No Usable Ace, {} Episodes'.format(few_episodes),
              'No Usable Ace, {} Episodes'.format(many_episodes)]
    _, axes = plt.subplots(2, 2, figsize=(40, 30))
    plt.subplots_adjust(wspace=0.1, hspace=0.2)
    axes = axes.flatten()
    for state, title, axis in zip(states, titles, axes):
        # flip rows so the largest player sum is drawn at the top
        fig = sns.heatmap(np.flipud(state), cmap="YlGnBu", ax=axis, xticklabels=range(1, 11),
                          yticklabels=list(reversed(range(12, 22))))
        fig.set_ylabel('player sum', fontsize=30)
        fig.set_xlabel('dealer showing', fontsize=30)
        fig.set_title(title, fontsize=30)
    # savefig raises FileNotFoundError if the target directory is missing
    os.makedirs('./images', exist_ok=True)
    plt.savefig('./images/black_jack_mc_evaluation.png')
    plt.close()
def montecarlo_exploring_starts(episodes=500000):
    """Plot the greedy policy and values learned by Monte Carlo ES.

    Runs monte_carlo_es for the given number of episodes. Takes the max (value)
    and argmax (policy) over actions for each state, and saves the four
    heatmaps to ./images/black_jack_mc_es.png, creating ./images if needed.

    Keyword Arguments:
        episodes -- number of episodes passed to monte_carlo_es
    """
    import os

    state_action_values = monte_carlo_es(episodes)
    state_value_no_usable_ace = np.max(state_action_values[:, :, 0, :], axis=-1)
    state_value_usable_ace = np.max(state_action_values[:, :, 1, :], axis=-1)
    # get the optimal policy
    action_no_usable_ace = np.argmax(state_action_values[:, :, 0, :], axis=-1)
    action_usable_ace = np.argmax(state_action_values[:, :, 1, :], axis=-1)
    images = [action_usable_ace,
              state_value_usable_ace,
              action_no_usable_ace,
              state_value_no_usable_ace]
    titles = ['Optimal policy with usable Ace',
              'Optimal value with usable Ace',
              'Optimal policy without usable Ace',
              'Optimal value without usable Ace']
    _, axes = plt.subplots(2, 2, figsize=(40, 30))
    plt.subplots_adjust(wspace=0.1, hspace=0.2)
    axes = axes.flatten()
    for image, title, axis in zip(images, titles, axes):
        # flip rows so the largest player sum is drawn at the top
        fig = sns.heatmap(np.flipud(image), cmap="YlGnBu", ax=axis, xticklabels=range(1, 11),
                          yticklabels=list(reversed(range(12, 22))))
        fig.set_ylabel('player sum', fontsize=30)
        fig.set_xlabel('dealer showing', fontsize=30)
        fig.set_title(title, fontsize=30)
    # savefig raises FileNotFoundError if the target directory is missing
    os.makedirs('./images', exist_ok=True)
    plt.savefig('./images/black_jack_mc_es.png')
    plt.close()
if __name__ == '__main__':
    # Run the state-value evaluation first, then Monte Carlo ES control.
    for experiment in (montecarlo_evaluation, montecarlo_exploring_starts):
        experiment()