diff --git a/easy_rec/python/compat/optimizers.py b/easy_rec/python/compat/optimizers.py
index d31a4cd41..0fc98c984 100644
--- a/easy_rec/python/compat/optimizers.py
+++ b/easy_rec/python/compat/optimizers.py
@@ -362,51 +362,53 @@ def optimize_loss(loss,
   #                  clip_ops.global_norm(list(zip(*gradients))[0]))
 
   # Optionally clip gradients by global norm.
-  if isinstance(clip_gradients, float):
-    # gradients = _clip_gradients_by_norm(gradients, clip_gradients)
-    sparse_norm, dense_norm, grad_norm = _get_grad_norm(
-        gradients, embedding_parallel)
-    summary.scalar('global_norm/sparse_grad', sparse_norm)
-    summary.scalar('global_norm/dense_grad', dense_norm)
-    summary.scalar('global_norm/gradient_norm', grad_norm)
-    grads = [x[0] for x in gradients]
-    vars = [x[1] for x in gradients]
-    clipped_grads, _ = clip_ops.clip_by_global_norm(
-        grads, clip_gradients, use_norm=grad_norm)
-    gradients = list(zip(clipped_grads, vars))
-  elif callable(clip_gradients):
-    gradients = clip_gradients(gradients)
-  elif clip_gradients is not None:
-    raise ValueError('Unknown type %s for clip_gradients' %
-                     type(clip_gradients))
+  if not embedding_parallel:
+    if isinstance(clip_gradients, float):
+      # gradients = _clip_gradients_by_norm(gradients, clip_gradients)
+      sparse_norm, dense_norm, grad_norm = _get_grad_norm(
+          gradients, embedding_parallel)
+      summary.scalar('global_norm/sparse_grad', sparse_norm)
+      summary.scalar('global_norm/dense_grad', dense_norm)
+      summary.scalar('global_norm/gradient_norm', grad_norm)
+      grads = [x[0] for x in gradients]
+      vars = [x[1] for x in gradients]
+      clipped_grads, _ = clip_ops.clip_by_global_norm(
+          grads, clip_gradients, use_norm=grad_norm)
+      gradients = list(zip(clipped_grads, vars))
+    elif callable(clip_gradients):
+      gradients = clip_gradients(gradients)
+    elif clip_gradients is not None:
+      raise ValueError('Unknown type %s for clip_gradients' %
+                       type(clip_gradients))
 
   # Add scalar summary for loss.
   if 'loss' in summaries:
     summary.scalar('loss', loss)
 
   # Add histograms for variables, gradients and gradient norms.
+
+  for gradient, variable in gradients:
+    if isinstance(gradient, indexed_slices.IndexedSlices):
+      grad_values = gradient.values
+    else:
+      grad_values = gradient
+
+    if grad_values is not None:
+      var_name = variable.name.replace(':', '_')
+      if 'gradients' in summaries:
+        summary.histogram('gradients/%s' % var_name, grad_values)
+      if 'gradient_norm' in summaries:
+        summary.scalar('gradient_norm/%s' % var_name,
+                       clip_ops.global_norm([grad_values]))
+
+  if not embedding_parallel:
-  for gradient, variable in gradients:
-    if isinstance(gradient, indexed_slices.IndexedSlices):
-      grad_values = gradient.values
-    else:
-      grad_values = gradient
-
-    if grad_values is not None:
-      var_name = variable.name.replace(':', '_')
-      if 'gradients' in summaries:
-        summary.histogram('gradients/%s' % var_name, grad_values)
-      if 'gradient_norm' in summaries:
-        summary.scalar('gradient_norm/%s' % var_name,
-                       clip_ops.global_norm([grad_values]))
-
-  if clip_gradients is not None and ('global_gradient_norm' in summaries or
-                                     'gradient_norm' in summaries):
-    sparse_norm, dense_norm, grad_norm = _get_grad_norm(
-        gradients, embedding_parallel)
-    summary.scalar('global_norm/clipped_sparse_grad', sparse_norm)
-    summary.scalar('global_norm/clipped_dense_grad', dense_norm)
-    summary.scalar('global_norm/clipped_gradient_norm', grad_norm)
+    if clip_gradients is not None and ('global_gradient_norm' in summaries or
+                                       'gradient_norm' in summaries):
+      sparse_norm, dense_norm, grad_norm = _get_grad_norm(
+          gradients, embedding_parallel)
+      summary.scalar('global_norm/clipped_sparse_grad', sparse_norm)
+      summary.scalar('global_norm/clipped_dense_grad', dense_norm)
+      summary.scalar('global_norm/clipped_gradient_norm', grad_norm)
 
   # Create gradient updates.
   def _apply_grad():
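In short, the patch wraps the float/callable clip_gradients handling and the clipped-global-norm summaries in an "if not embedding_parallel:" guard, so global-norm clipping only runs in the non-parallel case, while the per-variable gradient histogram/norm summaries remain outside the guard. The sketch below is not the EasyRec code: it is a minimal, standalone illustration of the guarded clipping step using public TensorFlow APIs (tf.linalg.global_norm and tf.clip_by_global_norm, whose use_norm argument matches the diff's use_norm=grad_norm). The helper name clip_gradients_sketch is hypothetical, and _get_grad_norm / embedding_parallel belong to the patched optimize_loss.

import tensorflow as tf


def clip_gradients_sketch(grads_and_vars, clip_norm, embedding_parallel=False):
  """Clip (grad, var) pairs by global norm unless embedding_parallel is set."""
  if embedding_parallel:
    # Mirror the diff's guard: with embedding parallelism, the clipping path
    # (and its clipped-norm summaries) is skipped entirely.
    return grads_and_vars
  grads = [g for g, _ in grads_and_vars]
  variables = [v for _, v in grads_and_vars]
  # Stand-in for _get_grad_norm: compute one global norm and reuse it for
  # clipping, the same way the diff passes use_norm=grad_norm.
  grad_norm = tf.linalg.global_norm(grads)
  clipped, _ = tf.clip_by_global_norm(grads, clip_norm, use_norm=grad_norm)
  return list(zip(clipped, variables))


# Toy usage: a single gradient of norm 5.0 is scaled down to norm 1.0.
w = tf.Variable([1.0, 2.0])
g = tf.constant([3.0, 4.0])
print(clip_gradients_sketch([(g, w)], clip_norm=1.0))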