KungFu documentation¶
KungFu aims to make distributed machine learning easy, adaptive and scalable.
Getting started¶
We try to keep it as simple as possible to install, deploy and run KungFu. KungFu does not require extra deployments like parameter servers or heavy dependencies like OpenMPI and NCCL as in Horovod. KungFu can run on your laptop, your desktop and your server, with and without GPUs. Please follow the instructions in the README to install KungFu.
Examples¶
We provide examples that show how to use KungFu with different TensorFlow APIs and Keras models.
Session¶
TensorFlow Session is a low-level but powerful interface that allows you to compile a static graph for iterative training. Session is the core of TensorFlow 1 programs. To enable KungFu, wrap your tf.train.Optimizer in a KungFu distributed optimizer, and use BroadcastGlobalVariablesOp to broadcast global variables at the first step of your training.
import tensorflow as tf
# Build model...
loss = ...
opt = tf.train.AdamOptimizer(0.01)
# KungFu Step 1: Wrap tf.optimizer in KungFu optimizers
from kungfu.tensorflow.optimizers import SynchronousSGDOptimizer
opt = SynchronousSGDOptimizer(opt)
# Make training operation
train_op = opt.minimize(loss)
# Train your model
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # KungFu Step 2: ensure distributed workers start with consistent states
    from kungfu.tensorflow.initializer import BroadcastGlobalVariablesOp
    sess.run(BroadcastGlobalVariablesOp())
    for step in range(10):
        sess.run(train_op)
You can find the full training example: TensorFlow 1 Session
Estimator¶
TensorFlow Estimator is the high-level API for TensorFlow 1 programs. To enable KungFu, you need to wrap your tf.train.Optimizer in a KungFu distributed optimizer, and register BroadcastGlobalVariablesHook as a hook for the estimator.
import tensorflow as tf
# An Estimator model_fn must accept at least a `features` argument
def model_func(features, labels, mode):
    loss = ...
    opt = tf.train.AdamOptimizer(0.01)
    # KungFu Step 1: Wrap tf.optimizer in KungFu optimizers
    from kungfu.tensorflow.optimizers import SynchronousAveragingOptimizer
    opt = SynchronousAveragingOptimizer(opt)
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=loss,
        train_op=opt.minimize(loss))
# KungFu Step 2: register the broadcast global variables hook
from kungfu.tensorflow.initializer import BroadcastGlobalVariablesHook
hooks = [BroadcastGlobalVariablesHook()]
estimator = tf.estimator.Estimator(model_fn=model_func,
                                   model_dir=FLAGS.model_dir)
for _ in range(10):
    estimator.train(input_fn=train_data, hooks=hooks)
You can find the full training example: TensorFlow 1 Estimator
GradientTape¶
TensorFlow 2 supports eager execution, which makes it easier to build dynamic models. The core of eager execution is tf.GradientTape. To enable KungFu, you need to wrap your tf.keras.optimizers.Optimizer in a KungFu distributed optimizer, and use broadcast_variables to broadcast global variables at the end of the first step of training.
import tensorflow as tf
# Build the dataset...
dataset = ...
# Build the model and the loss function...
mnist_model = ...
loss = ...
opt = tf.keras.optimizers.SGD(0.01)
# KungFu Step 1: Wrap tf.optimizer in KungFu optimizers
from kungfu.tensorflow.optimizers import SynchronousSGDOptimizer
opt = SynchronousSGDOptimizer(opt)
@tf.function
def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)
    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))
    # KungFu Step 2: broadcast global variables
    if first_batch:
        from kungfu.tensorflow.initializer import broadcast_variables
        broadcast_variables(mnist_model.variables)
        broadcast_variables(opt.variables())
    return loss_value
for batch, (images, labels) in enumerate(dataset.take(10000)):
    loss_value = training_step(images, labels, batch == 0)
You can find the full training example: TensorFlow 2 GradientTape
TensorFlow Keras¶
Keras has been the high-level training API for TensorFlow since 1.11 and is the default interface in TensorFlow 2. To enable KungFu, you need to wrap your tf.keras.optimizers.Optimizer in a KungFu distributed optimizer, and use BroadcastGlobalVariablesCallback as a callback for the Keras model.
import tensorflow as tf
# Build dataset...
dataset = ...
# Build model...
model = tf.keras.Sequential(...)
opt = tf.keras.optimizers.SGD(0.01)
# KungFu Step 1: Wrap tf.optimizer in KungFu optimizers
from kungfu.tensorflow.optimizers import SynchronousSGDOptimizer
opt = SynchronousSGDOptimizer(opt)
model.compile(loss=tf.losses.SparseCategoricalCrossentropy(),
              optimizer=opt,
              metrics=['accuracy'])
# KungFu Step 2: Register a broadcast callback
from kungfu.tensorflow.initializer import BroadcastGlobalVariablesCallback
model.fit(dataset,
          steps_per_epoch=500,
          epochs=1,
          callbacks=[BroadcastGlobalVariablesCallback()])
Here are two full training examples: TensorFlow 1 Keras and TensorFlow 2 Keras
KungFu APIs¶
KungFu has high-level optimizer APIs that allow you to scale out training transparently. It also has a low-level API that makes it easy to implement distributed training strategies. The following public APIs have been released so far.
Distributed optimizers¶
KungFu provides optimizers that implement various distributed training algorithms. These optimizers transparently scale out training that uses tf.train.Optimizer or tf.keras.optimizers.Optimizer.
- kungfu.tensorflow.optimizers.PairAveragingOptimizer(optimizer, fuse_requests=True, fused_model_name=None, name=None, use_locking=False, with_keras=False)¶
PairAveragingOptimizer implements the [AD-PSGD] algorithm.
Every iteration of training, this optimizer:
- Randomly selects a peer in the current cluster.
- Pulls the selected peer’s model.
- Performs model averaging with the local model.
- Applies local gradients.
- Saves the model to a local store that other peers can pull from.
[AD-PSGD] Asynchronous Decentralized Parallel Stochastic Gradient Descent, ICML 2018, AD-PSGD Paper
- Arguments:
- optimizer {tf.train.Optimizer, tf.keras.optimizers.Optimizer} – Optimizer to use for computing gradients and applying updates.
- Keyword Arguments:
- fuse_requests {bool} – Fusing requests to amortise communication cost at the cost of extra GPU memory and cycles. (default: {True})
- fused_model_name {str} – The unique name for the fused model kept in a local store. (default: {None})
- name {str} – name prefix for the operations created when applying gradients. Defaults to “KungFu” followed by the provided optimizer type. (default: {None})
- use_locking {bool} – Whether to use locking when updating variables. (default: {False})
- with_keras {bool} – Runs with pure Keras or not (default: {False})
- Raises:
- TypeError: Wrapped optimizer is not a subclass of tf.train.Optimizer or tf.keras.optimizers.Optimizer
- Returns:
- optimizer {tf.train.Optimizer, tf.keras.optimizers.Optimizer} – KungFu distributed optimizer
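The per-iteration steps of PairAveragingOptimizer can be simulated in plain Python. This is only a minimal sketch under stated assumptions, not KungFu's implementation: models are flat lists of floats, a dict stands in for the peers' local stores, the local update is plain SGD, and the names `pair_averaging_step`, `store` and `sgd_lr` are hypothetical.

```python
import random


def pair_averaging_step(rank, local_model, local_grads, store, sgd_lr=0.1, rng=random):
    """Simulate one pair-averaging iteration for the peer `rank`.

    `store` maps rank -> last saved model and stands in for the
    local stores that other peers pull from (an assumption of this sketch).
    """
    # 1. Randomly select another peer in the current cluster.
    peers = [r for r in store if r != rank]
    target = rng.choice(peers)
    # 2. Pull the selected peer's model.
    remote_model = store[target]
    # 3. Average the pulled model with the local model.
    averaged = [(a + b) / 2.0 for a, b in zip(local_model, remote_model)]
    # 4. Apply the local gradients (plain SGD is assumed here).
    updated = [w - sgd_lr * g for w, g in zip(averaged, local_grads)]
    # 5. Save the model so that other peers can pull it.
    store[rank] = list(updated)
    return updated


# Two peers, so peer 0 can only select peer 1.
store = {0: [1.0, 1.0], 1: [3.0, 5.0]}
new_model = pair_averaging_step(0, store[0], [0.0, 10.0], store, sgd_lr=0.1)
```

Because each peer only talks to one randomly chosen peer per iteration, no global synchronisation point is needed in this scheme.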
- kungfu.tensorflow.optimizers.SynchronousSGDOptimizer(optimizer, nccl=False, nccl_fusion=False, hierarchical_nccl=False, monitor=False, name=None, use_locking=False, with_keras=False)¶
SynchronousSGDOptimizer implements the [S-SGD] algorithm.
This optimizer is equivalent to the DistributedOptimizer in Horovod. In every training iteration, it averages the gradients across all peers and applies the averaged gradients, which keeps the model replicas from diverging.
[S-SGD] Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour, 2017, S-SGD Paper
- Arguments:
- optimizer {tf.train.Optimizer, tf.keras.optimizers.Optimizer} – Optimizer to use for computing gradients and applying updates.
- Keyword Arguments:
- nccl {bool} – Whether to use NCCL to average gradients. (default: {False})
- nccl_fusion {bool} – Whether to fuse all gradients to amortise NCCL operation launch cost. (default: {False})
- name {str} – name prefix for the operations created when applying gradients. Defaults to “KungFu” followed by the provided optimizer type. (default: {None})
- use_locking {bool} – Whether to use locking when updating variables. (default: {False})
- with_keras {bool} – Runs with pure Keras or not (default: {False})
- Raises:
- TypeError: Wrapped optimizer is not a subclass of tf.train.Optimizer or tf.keras.optimizers.Optimizer
- Returns:
- optimizer {tf.train.Optimizer, tf.keras.optimizers.Optimizer} – KungFu distributed optimizer
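The gradient averaging performed by synchronous SGD can be illustrated without TensorFlow. The following is a minimal sketch, assuming flat lists of floats for models and gradients and plain SGD as the wrapped optimizer; `average_gradients` and `ssgd_step` are hypothetical helper names, not KungFu functions.

```python
def average_gradients(per_worker_grads):
    """Element-wise average of the gradient vectors from all workers."""
    num_workers = len(per_worker_grads)
    return [sum(vals) / num_workers for vals in zip(*per_worker_grads)]


def ssgd_step(replicas, per_worker_grads, lr):
    """Apply the same averaged gradient to every model replica."""
    avg = average_gradients(per_worker_grads)
    return [[w - lr * g for w, g in zip(model, avg)] for model in replicas]


# Replicas are identical after the initial broadcast.
replicas = [[1.0, 2.0], [1.0, 2.0]]
# Each worker computed gradients on a different mini-batch.
grads = [[0.5, 1.0], [1.5, 3.0]]
replicas = ssgd_step(replicas, grads, lr=0.5)
```

Since every replica starts from the same state and applies the same averaged gradient, the replicas remain identical after the step.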
- kungfu.tensorflow.optimizers.SynchronousAveragingOptimizer(optimizer, name=None, alpha=0.1, use_locking=False, with_keras=False)¶
SynchronousAveragingOptimizer implements the [SMA] algorithm.
[EA-SGD] proposed using model averaging to train deep learning models and proved its convergence. [SMA] further improves on the [EA-SGD] results and shows that model averaging can benefit small-batch training and converge faster than synchronous SGD.
[EA-SGD] Deep learning with Elastic Averaging SGD, NIPS 2015, EA-SGD Paper
[SMA] CrossBow: Scaling Deep Learning with Small Batch Sizes on Multi-GPU Servers, VLDB 2019, SMA Paper
- Arguments:
- optimizer {tf.train.Optimizer, tf.keras.optimizers.Optimizer} – Optimizer to use for computing gradients and applying updates.
- Keyword Arguments:
- name {str} – name prefix for the operations created when applying gradients. Defaults to “KungFu” followed by the provided optimizer type. (default: {None})
- alpha {float} – the ratio of a central model during averaging (Check the SMA and EA-SGD papers for its intuition). (default: {0.1})
- use_locking {bool} – Whether to use locking when updating variables. (default: {False})
- with_keras {bool} – Runs with pure Keras or not (default: {False})
- Raises:
- TypeError: Wrapped optimizer is not a subclass of tf.train.Optimizer or tf.keras.optimizers.Optimizer
- Returns:
- optimizer {tf.train.Optimizer, tf.keras.optimizers.Optimizer} – KungFu distributed optimizer
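The role of alpha in model averaging can be sketched in plain Python. This is one plausible reading of the parameter, offered as an assumption rather than as KungFu's exact update rule: each replica is pulled toward the average of all replicas with weight alpha, and then applies its own local gradients. The function `sma_step` and the `lr` parameter are hypothetical.

```python
def sma_step(replicas, per_worker_grads, alpha=0.1, lr=0.1):
    """Simulate one synchronous model-averaging step over all replicas."""
    num_workers = len(replicas)
    # Average model across all replicas (an all-reduce in a real cluster).
    avg_model = [sum(vals) / num_workers for vals in zip(*replicas)]
    new_replicas = []
    for model, grads in zip(replicas, per_worker_grads):
        # Pull the replica toward the average model with weight alpha.
        corrected = [(1 - alpha) * w + alpha * a for w, a in zip(model, avg_model)]
        # Apply the replica's own local gradients (plain SGD assumed).
        new_replicas.append([w - lr * g for w, g in zip(corrected, grads)])
    return new_replicas


# Two diverged replicas and zero gradients, to isolate the averaging effect.
replicas = [[0.0], [2.0]]
grads = [[0.0], [0.0]]
replicas = sma_step(replicas, grads, alpha=0.5)
```

With alpha=0.5 the two replicas move halfway toward their common average of 1.0; a smaller alpha lets the replicas explore more independently between corrections.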
Global variable initializers¶
KungFu provides various initializers to help you synchronize the global variables of distributed training workers at the beginning of training. These initializers are used with tf.Session, tf.estimator, tf.GradientTape and tf.keras, respectively.
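The effect of such an initializer can be pictured with a small simulation. This is a hedged sketch, not KungFu code: `simulate_broadcast` is a hypothetical helper, variables are flat lists of floats, and the choice of worker 0 as the source is an assumption made for illustration.

```python
def simulate_broadcast(per_worker_variables, root=0):
    """Return the state after every worker adopts the root worker's variables."""
    source = per_worker_variables[root]
    # Each worker receives its own copy of the source values.
    return [list(source) for _ in per_worker_variables]


# Each worker initialised its variables with a different random seed.
workers = [[0.1, 0.2], [0.7, -0.3], [0.0, 0.9]]
workers = simulate_broadcast(workers)
```

After the broadcast all workers hold the same values, which is the consistent starting state that the distributed optimizers rely on.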
- kungfu.tensorflow.initializer.broadcast_variables(variables)¶
A TensorFlow function that broadcasts global variables.
This function is often used with tf.GradientTape or embedded as part of a training program.
- kungfu.tensorflow.initializer.BroadcastGlobalVariablesCallback(with_keras=False)¶
Keras callback that broadcasts global variables at the beginning of training.
- Keyword Arguments:
- with_keras {bool} – Runs with pure Keras or not (default: {False})
- Returns:
- {tf.keras.callbacks.Callback, keras.callbacks.Callback} – Callback
- class kungfu.tensorflow.initializer.BroadcastGlobalVariablesHook¶
A TensorFlow hook that broadcasts global variables at the beginning of training.
This hook is often used with tf.train.MonitoredSession and tf.estimator.Estimator.
- after_create_session(session, coord)¶
Broadcast global variables after creating the session.
- begin()¶
Create a broadcast op at the beginning.
- kungfu.tensorflow.initializer.BroadcastGlobalVariablesOp()¶
A TensorFlow operator that broadcasts global variables.
This operator is often used with the low-level tf.Session.
Cluster management¶
When scaling out training, you often want to adjust the parameters of your training program, for example, sharding the training dataset or scaling the learning rate of the optimizer. This can be achieved using the following cluster management APIs.
- kungfu.python.current_cluster_size()¶
Get the number of peers in the current cluster.
- kungfu.python.current_local_rank()¶
Get the current local rank of this peer.
- kungfu.python.current_local_size()¶
Get the number of local peers in the current cluster.
- kungfu.python.current_rank()¶
Get the current rank of this peer.
- kungfu.python.detached()¶
Check if the peer is detached.
- kungfu.python.run_barrier()¶
Run the barrier operation eagerly.
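Dataset sharding and learning-rate scaling, the two adjustments mentioned above, might look like the following. This is a sketch under stated assumptions: `shard_indices` and `scaled_learning_rate` are hypothetical helpers, the round-robin sharding and the linear scaling rule are illustrative choices, and the rank and cluster size are hard-coded so that the snippet runs without KungFu installed.

```python
def shard_indices(num_samples, rank, cluster_size):
    """Indices of the samples assigned to the peer `rank` (round-robin)."""
    return list(range(rank, num_samples, cluster_size))


def scaled_learning_rate(base_lr, cluster_size):
    """Linear scaling: grow the learning rate with the number of peers."""
    return base_lr * cluster_size


# In a KungFu program these two values would come from
# kungfu.python.current_rank() and kungfu.python.current_cluster_size().
rank, cluster_size = 1, 4

my_shard = shard_indices(10, rank, cluster_size)
lr = scaled_learning_rate(0.01, cluster_size)
```

Here peer 1 of 4 trains on samples 1, 5 and 9, and the base learning rate of 0.01 is multiplied by the cluster size.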
TensorFlow operators¶
KungFu provides TensorFlow operators to help you realise new distributed training optimizers.
- kungfu.tensorflow.ops.all_gather(t)¶
Create a new all_gather operator for the given tensor.
- Inputs:
- A tensor of any shape. The shape must be consistent on all peers.
- Returns:
- A tensor whose leading dimension equals the number of peers and whose remaining dimensions equal those of the original shape.
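The shape contract of all_gather can be checked with nested lists standing in for tensors. This is an illustrative simulation only: `shape_of` and `simulate_all_gather` are hypothetical helpers, and the ordering of the gathered tensors by peer is an assumption.

```python
def shape_of(t):
    """Shape of a rectangular nested list, e.g. [[1, 2], [3, 4]] -> (2, 2)."""
    shape = []
    while isinstance(t, list):
        shape.append(len(t))
        if not t:
            break
        t = t[0]
    return tuple(shape)


def simulate_all_gather(per_peer_tensors):
    """Stack one same-shaped tensor per peer along a new leading dimension."""
    shapes = {shape_of(t) for t in per_peer_tensors}
    if len(shapes) != 1:
        raise ValueError("the shape must be consistent on all peers")
    return [t for t in per_peer_tensors]


# Three peers, each contributing a tensor of shape (2,).
per_peer = [[1, 2], [3, 4], [5, 6]]
gathered = simulate_all_gather(per_peer)
gathered_shape = shape_of(gathered)
```

The result has shape (3, 2): the number of peers followed by the original shape.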
- kungfu.tensorflow.ops.barrier()¶
Create a new barrier operator.
- kungfu.tensorflow.ops.broadcast(t)¶
Create a new broadcast operator for the given tensor.
- kungfu.tensorflow.ops.cluster_size()¶
- Returns:
- a scalar tensor of int32 representing the cluster size.
- kungfu.tensorflow.ops.group_all_reduce(ts)¶
Create a list of all_reduce operators for the given tensor list.
- kungfu.tensorflow.ops.rank()¶
- Returns:
- a scalar tensor of int32 representing the rank.
- kungfu.tensorflow.ops.resize(n)¶
Resize the cluster to n.
- Inputs:
- n: A scalar tensor of uint32.
- Returns:
- A scalar tensor of bool indicating whether the cluster has changed.
- kungfu.tensorflow.ops.set_tree(tree)¶
Set the default communication tree.
- Inputs:
- tree: an int32 tensor with shape [n], where
- n is the number of peers in the current cluster;
- tree[i] is the parent of i if tree[i] != i;
- i is the root if tree[i] == i.
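The parent-array encoding described above can be decoded with a few lines of plain Python. This is a sketch for illustration: `describe_tree` is a hypothetical helper that works on a Python list rather than an int32 tensor, and the check for exactly one root is an assumption about what a valid tree looks like.

```python
def describe_tree(tree):
    """Return (root, children) for a parent array where tree[i] is the parent of i."""
    # A peer is a root when it is its own parent.
    roots = [i for i, parent in enumerate(tree) if parent == i]
    if len(roots) != 1:
        raise ValueError("expected exactly one root, got %d" % len(roots))
    children = {i: [] for i in range(len(tree))}
    for i, parent in enumerate(tree):
        if parent != i:
            children[parent].append(i)
    return roots[0], children


# Four peers: 0 is the root, 1 and 2 hang off 0, and 3 hangs off 1.
root, children = describe_tree([0, 0, 0, 1])
```

For this example the root is peer 0, with children 1 and 2, and peer 3 is the only child of peer 1.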