Add parameter server train & side-car eval on k8s #182
selcukgun wants to merge 5 commits into tensorflow:master
Conversation
ResNet56 model (with a custom training loop): variables are created on parameter server jobs and updated by workers. Evaluation is done by a dedicated job that uses the checkpoints saved during training (side-car evaluation). The model is trained on the CIFAR-10 dataset.
Jinja template now turns off side-car evaluation by default so that only the inline distributed evaluation added with this CL is used. README updated. Added efficiency wrappers that will be useful once GPU is supported with ParameterServerStrategy. Moved the Kubernetes Jinja template and renderer script to a dedicated subdirectory.
Please first read the
[documentation](https://www.tensorflow.org/tutorials/distribute/parameter_server_training)
of Distribution Strategy for parameter server training. We also assume that readers
of this page are familiar with [Google Cloud](https://cloud.google.com/) and
- kubernetes/template.yaml.jinja: jinja template used for generating Kubernetes manifests
- kubernetes/render_template.py: script for rendering the jinja template
- Dockerfile.resnet_cifar_ps_strategy: a docker file to build the model image
- resnet_cifar_ps_strategy.py: script for running any type of parameter server training task based on `TF_CONFIG` environment variable
"any type of ..." seems too general, maybe just say "a ResNet example using Cifar dataset for parameter server training"
```python
BATCH_SIZE = 64
EVAL_BATCH_SIZE = 8


def create_in_process_cluster(num_workers, num_ps):
```
Could you update the worker_config part according to this tutorial? https://www.tensorflow.org/tutorials/distribute/parameter_server_training#in-process_cluster
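For reference, the tutorial section linked above builds the in-process cluster roughly as follows; this is a sketch based on that tutorial (portpicker ports plus a worker ConfigProto), not necessarily the exact code in this PR.

```python
import multiprocessing
import portpicker
import tensorflow as tf


def create_in_process_cluster(num_workers, num_ps):
  """Creates and starts local servers and returns a cluster resolver."""
  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]

  cluster_dict = {"worker": ["localhost:%s" % port for port in worker_ports]}
  if num_ps > 0:
    cluster_dict["ps"] = ["localhost:%s" % port for port in ps_ports]
  cluster_spec = tf.train.ClusterSpec(cluster_dict)

  # Workers need some inter_ops threads to work properly.
  worker_config = tf.compat.v1.ConfigProto()
  if multiprocessing.cpu_count() < num_workers + 1:
    worker_config.inter_op_parallelism_threads = num_workers + 1

  for i in range(num_workers):
    tf.distribute.Server(
        cluster_spec, job_name="worker", task_index=i,
        config=worker_config, protocol="grpc")
  for i in range(num_ps):
    tf.distribute.Server(
        cluster_spec, job_name="ps", task_index=i, protocol="grpc")

  return tf.distribute.cluster_resolver.SimpleClusterResolver(
      cluster_spec, rpc_layer="grpc")
```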
```python
  set up distributed training
  """

  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
```
let's use tf.distribute.experimental.ParameterServerStrategy.
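A minimal sketch of the suggested swap; `cluster_resolver` is assumed to come from the surrounding script (for example, the in-process cluster helper above).

```python
# tf.distribute.experimental.ParameterServerStrategy is the public entry point
# for the ParameterServerStrategyV2 implementation referenced in the diff.
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
```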
```python
  logging.info("Finished joining at epoch %d. Training accuracy: %f.",
               epoch, train_accuracy.result())

  for _ in range(STEPS_PER_EPOCH):
```
Should evaluation use a different steps_per_epoch, since you have a different batch_size for evaluation?
Good point. Introducing EVAL_STEPS_PER_EPOCH and setting it to 88 in the next patch shortly. This gives us a probability of 0.99 for a row in the dataset to be evaluated.
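For context, a rough sketch of the coordinator-driven loop these hunks belong to, following the parameter server training tutorial; `NUM_EPOCHS`, `worker_train_fn`, and `per_worker_iterator` are assumed names rather than the PR's exact identifiers.

```python
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)

for epoch in range(NUM_EPOCHS):
  train_accuracy.reset_states()
  for _ in range(STEPS_PER_EPOCH):
    # schedule() dispatches one step to a free worker and returns immediately.
    coordinator.schedule(worker_train_fn, args=(per_worker_iterator,))
  # join() blocks until every step scheduled for this epoch has finished.
  coordinator.join()
  logging.info("Finished joining at epoch %d. Training accuracy: %f.",
               epoch, train_accuracy.result())
```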
```python
  logging.info("Finished joining at epoch %d. Training accuracy: %f.",
               epoch, train_accuracy.result())

  for _ in range(STEPS_PER_EPOCH):
```
Could you add a comment here saying that we are running inline distributed evaluation; in this case an evaluator job is not necessary.
Also addressed the following:

* Added inter_ops for workers
* Replaced parameter_server_strategy_v2.ParameterServerStrategyV2 with tf.distribute.experimental.ParameterServerStrategy
* Clarified the resnet_cifar_ps_strategy.py description
* Indicated that the side-car evaluation job is not needed since we are running inline evaluation
* Removed redundant spaces
```python
flags.DEFINE_string("data_dir", "gs://cifar10_data/",
                    "Directory for Resnet Cifar model input. Follow the "
                    "instruction here to get Cifar10 data: "
                    "https://github.com/tensorflow/models/tree/r1.13.0/official/resnet#cifar-10")
```
Split the help argument into multiple lines for readability; the pieces are displayed as one concatenated string when the --help command-line arg is passed.
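To illustrate why the split works (a generic Python snippet, not from the PR): adjacent string literals are joined at parse time, so the multi-line help text is still a single string when absl prints --help.

```python
msg = ("Directory for Resnet Cifar model input. "
       "Follow the instruction here to get Cifar10 data.")
assert "\n" not in msg  # the two literals were merged into one line
```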
```python
    parse_record_fn=cifar_preprocessing.parse_record,
    dtype=tf.float32,
    drop_remainder=True)
eval_dataset_fn = lambda _: cifar_preprocessing.input_fn(
```
Is the eval data shuffled? If not, could you add a comment and a TODO?
Maybe you can just append a shuffle at the end of the dataset?
input_fn already shuffles the training data using process_record_dataset: code link
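For completeness, a sketch of what the shuffle suggestion discussed above could look like; only `parse_record_fn`, `dtype`, `drop_remainder`, `EVAL_BATCH_SIZE`, and the `data_dir` flag appear in the diffs, so the remaining arguments and the buffer size are assumptions.

```python
def eval_dataset_fn(_):
  ds = cifar_preprocessing.input_fn(
      is_training=False,              # assumed: evaluation split
      data_dir=flags.FLAGS.data_dir,  # flag defined in the hunk above
      batch_size=EVAL_BATCH_SIZE,
      parse_record_fn=cifar_preprocessing.parse_record,
      dtype=tf.float32,
      drop_remainder=True)
  # Reviewer's suggestion: append a shuffle at the end of the (batched) dataset.
  return ds.shuffle(buffer_size=64)
```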
```python
  # Since we are running inline evaluation below, a side-car evaluator job is not necessary.
  for _ in range(EVAL_STEPS_PER_EPOCH):
    coordinator.schedule(worker_eval_fn, args=(per_worker_eval_iterator,))
```
We can probably build a similar API for DTensor async training. A major difficulty to sort out is what to do if worker_eval_fn (and/or replica_fn) is multi-mesh -- for example, if there is a summary Op that needs to run on the CPU.
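For readers following the scheduling call above, a hedged sketch of what a `worker_eval_fn` could look like; `model`, `strategy`, and `eval_accuracy` are assumed to be defined elsewhere in the training script.

```python
@tf.function
def worker_eval_fn(iterator):
  def replica_fn(batch):
    images, labels = batch
    predictions = model(images, training=False)
    eval_accuracy.update_state(labels, predictions)

  # Each scheduled call pulls one batch from the per-worker iterator and runs
  # the replica function under the strategy.
  batch = next(iterator)
  strategy.run(replica_fn, args=(batch,))
```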