@@ -65,7 +65,8 @@ def __init__(self,
                word_embed_mode=None,
                use_second_place_expert_prob=None,
                use_second_place_expert_prob_temp=None,
-               top_n_num_experts_per_token=3):
+               top_n_num_experts_per_token=3,
+               token_logging=False):
     self._hparams = HParams(
         moe_gating=moe_gating,
         moe_num_experts=num_experts,
@@ -97,6 +98,7 @@ def __init__(self,
             use_second_place_expert_prob_temp),
         moe_top_n_num_experts_per_token=top_n_num_experts_per_token)
     self._activation = activation
+    self.token_logging = token_logging
 
   def call(self, context, x, losses=None):
     """Call the layer."""
@@ -116,7 +118,13 @@ def call(self, context, x, losses=None):
       output_dim = self._hparams.moe_output_dim
     else:
       output_dim = context.model.model_dim
-    y, loss = transformer_moe_layer_v1(
+    if self.token_logging:
+      tokens = _detokenize(context.inputs, context.model.vocabulary)
+      x = mtf.Print(x, [tokens], "tokens:", summarize=1000)
+      extras = _windows(context.inputs, context.length_dim)
+    else:
+      extras = None
+    y, loss, extras = transformer_moe_layer_v1(
         x,
         output_dim,
         self._hparams,
@@ -127,7 +135,16 @@ def call(self, context, x, losses=None):
         nonpadding=context.nonpadding,
         activation=self._activation,
         num_microbatches=context.num_microbatches,
-        token_embeddings=context.input_embeddings)
+        token_embeddings=context.input_embeddings,
+        extras=extras)
+
+    if extras:
+      extras = _detokenize(extras, context.model.vocabulary)
+      experts_dim = mtf.Dimension("experts", self._hparams.moe_num_experts)
+      extras = mtf.unstack(extras, experts_dim)
+      for i, t in enumerate(extras):
+        y = mtf.Print(y, [t], "EXPERT %s:" % i, summarize=1000)
+
     if context.losses is not None:
       context.losses.append(loss)
     if not has_length_dim:
@@ -139,6 +156,23 @@ def call(self, context, x, losses=None):
     return y
 
 
+@gin.configurable
+def _windows(ids, length_dim, window_start=0, window_end=0):
+  to_stack = []
+  for offset in range(window_start, window_end + 1):
+    to_stack.append(mtf.shift(ids, -offset, length_dim, wrap=False))
+  return mtf.stack(to_stack, "window", axis=ids.shape.ndims)
+
+
+def _detokenize(ids, vocabulary):
+  return mtf.slicewise(
+      vocabulary.decode_tf,
+      [ids],
+      output_shape=mtf.Shape(ids.shape.dims[:-1]),
+      output_dtype=tf.string,
+      splittable_dims=ids.shape.dims[:-1])
+
+
 class MoE2D(transformer.TransformerLayer):
   """Mixture of Experts Layer."""
 
@@ -202,7 +236,7 @@ def call(self, context, x, losses=None):
 def transformer_moe_layer_v1(
     inputs, output_dim, hparams, train, variable_dtype,
     layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
-    num_microbatches=None, token_embeddings=None):
+    num_microbatches=None, token_embeddings=None, extras=None):
   """Local mixture of experts that works well on TPU.
 
   Adapted from the paper https://arxiv.org/abs/1701.06538
@@ -281,6 +315,7 @@ def transformer_moe_layer_v1(
       [batch_dim(s), length_dim, input_dim]. These are the word embeddings for
       that correspond to the inputs. These can optionally be used to make
       routing decisions.
+    extras: a tensor to dispatch (for debugging purposes)
 
   Returns:
     outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
@@ -344,6 +379,10 @@ def transformer_moe_layer_v1(
   # over which those groups are split.
   batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                       orig_inputs.shape.dims[-1])
+
+  if extras:
+    extras_dims = extras.shape.dims[len(batch_and_length_dims):]
+
   # Hack: we assume that
   #   "outer_batch" == replication of experts
   #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
@@ -381,6 +420,11 @@ def transformer_moe_layer_v1(
     token_embeddings = mtf.cast(
         mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype)
 
+  if extras:
+    extras = mtf.reshape(
+        extras,
+        [outer_batch_dim, num_groups_dim, group_size_dim] + extras_dims)
+
   # Each sequence sends expert_capacity positions to each expert.
   if train:
     capacity_factor = hparams.moe_capacity_factor_train
@@ -503,6 +547,17 @@ def transformer_moe_layer_v1(
           input_dim
       ]))
 
+  if extras:
+    extras = mtf.einsum([extras, mtf.cast(dispatch_tensor, extras.dtype)],
+                        mtf.Shape([
+                            outer_batch_dim, experts_dim_unsplit,
+                            num_groups_dim, expert_capacity_dim] + extras_dims))
+    extras = mtf.reshape(
+        extras,
+        mtf.Shape([
+            outer_batch_dim, experts_dim, batch_dim_unsplit,
+            expert_capacity_dim] + extras_dims))
+
   # Now feed the expert inputs through the experts.
   h = mtf.layers.dense_product(
       expert_inputs,
@@ -559,10 +614,15 @@ def _compute_output(hidden, layer_name):
     k = _compute_output(k_h, layer_name="k_wo")
     outputs.append(q)
     outputs.append(k)
-    return outputs, loss * hparams.moe_loss_coef
+    return outputs, loss * hparams.moe_loss_coef, None
   else:
     output = _compute_output(h, layer_name="wo")
-    return output, loss * hparams.moe_loss_coef
+    loss *= hparams.moe_loss_coef
+
+    if extras:
+      return output, loss, extras
+    else:
+      return output, loss, None
 
 
 def transformer_moe_layer_v2(
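
For intuition, here is a small standalone NumPy analogue of what the new _windows helper computes (the helper itself operates on Mesh TensorFlow tensors and adds a named "window" dimension). For each offset in [window_start, window_end], the token ids are shifted toward lower positions along the length dimension, assuming mtf.shift with a negative offset moves values toward lower indices; positions that run off the sequence become 0, matching wrap=False. The shifted copies are stacked along a new trailing "window" axis, so entry [b, p, k] holds the id of the token at position p + window_start + k. The function name and sample values below are illustrative only, not part of the change above.

    import numpy as np

    def windows_reference(ids, window_start=0, window_end=0):
      """NumPy sketch of _windows for a [batch, length] array of token ids."""
      to_stack = []
      for offset in range(window_start, window_end + 1):
        shifted = np.zeros_like(ids)
        if offset >= 0:
          # Shift toward lower positions; ids past the end become 0 (wrap=False).
          shifted[:, :ids.shape[1] - offset] = ids[:, offset:]
        else:
          # Negative offsets look backwards: shift right and zero-pad the front.
          shifted[:, -offset:] = ids[:, :offset]
        to_stack.append(shifted)
      # New trailing "window" axis: entry [b, p, k] is the token id at
      # position p + window_start + k of sequence b (or 0 past the end).
      return np.stack(to_stack, axis=-1)

    ids = np.array([[11, 12, 13, 14]])
    print(windows_reference(ids, window_start=0, window_end=2)[0, 0])  # [11 12 13]

In the layer above, these windows are the tensor dispatched as extras: after routing, they are detokenized and printed per expert, so each expert's log line shows the tokens it received (with the gin-configurable window defaults of 0, just the routed token itself).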