Example: Save and Load a TensorFlow Model
This post shows how to save and load a TensorFlow model using the `DNNClassifier` API.
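The key idea is to define, up front, a function (or class) that does three things:

- takes a model directory, where the model parameters are saved and later restored from;
- passes that directory to a `RunConfig`;
- returns a `tf.contrib.learn.Estimator`, for example a `tf.contrib.learn.DNNClassifier`.

Training and loading both build their estimator through this same function with the same directory. As a result, the estimator created for loading restores the parameters that training saved. See `make_estimator` below for details.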
import numpy as np
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.contrib.training.python.training import hparam
# Directory the estimator saves its parameters to and restores them from.
MODEL_DIR = 'your-model-dir'

# Hyperparameters for the training / evaluation run.
hparams = hparam.HParams(
    num_epochs=10,        # passes over the training data
    train_batch_size=50,
    eval_batch_size=50,
    eval_steps=10,
)
def make_estimator(model_dir):
    """Build a two-class DNNClassifier bound to ``model_dir``.

    The directory is handed to ``RunConfig``, so every estimator built by
    this function with the same ``model_dir`` saves its parameters there and
    restores them from there. Calling it once for training and again for
    prediction is what makes the "load" step work.

    Args:
        model_dir: Path of the directory that holds the model's parameters.

    Returns:
        A ``tf.contrib.learn.DNNClassifier`` with ``n_classes=2``, a single
        numeric input feature named ``'random_feature'``, and hidden layers
        of 1024, 512 and 256 units.
    """
    config = run_config.RunConfig(model_dir=model_dir)
    # The column key must match the dict key used by the numpy input functions.
    input_columns = [
        tf.feature_column.numeric_column(key='random_feature'),
    ]
    return tf.contrib.learn.DNNClassifier(
        config=config,
        n_classes=2,
        feature_columns=input_columns,
        hidden_units=[1024, 512, 256],
    )
dataset_size = 1000

# Features are random floats. Labels must be integer class ids in {0, 1}
# because the classifier is built with n_classes=2; continuous floats from
# np.random.rand are not valid class labels.
X_train = np.random.rand(dataset_size)
y_train = np.random.randint(2, size=dataset_size)
X_eval = np.random.rand(dataset_size)
y_eval = np.random.randint(2, size=dataset_size)

estimator = make_estimator(MODEL_DIR)

experiment = tf.contrib.learn.Experiment(
    estimator,
    train_input_fn=tf.estimator.inputs.numpy_input_fn(
        x={'random_feature': X_train},
        y=y_train,
        num_epochs=hparams.num_epochs,
        batch_size=hparams.train_batch_size,
        shuffle=True,
    ),
    eval_input_fn=tf.estimator.inputs.numpy_input_fn(
        x={'random_feature': X_eval},
        y=y_eval,
        num_epochs=None,  # repeats indefinitely; eval_steps bounds evaluation
        batch_size=hparams.eval_batch_size,
        shuffle=False,  # Don't shuffle evaluation data
    ),
    # Use the configured value instead of Experiment's built-in default.
    eval_steps=hparams.eval_steps,
)

# Training saves the parameters into MODEL_DIR (set via make_estimator's
# RunConfig), which is what the loading step below relies on.
experiment.train()
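Loading the model: build a new estimator with `make_estimator` pointing at the same `MODEL_DIR`, and it restores the parameters saved during training.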
# New feature values to score with the restored model.
X_predict = np.array([0.3, 0.4, 0.5])

# shuffle=False keeps the outputs in the same order as X_predict;
# num_epochs=1 feeds each example exactly once.
predict_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'random_feature': X_predict},
    num_epochs=1,
    batch_size=hparams.eval_batch_size,
    shuffle=False,
)

# Same builder, same MODEL_DIR -> the parameters written by training are reused.
estimator_from_file = make_estimator(MODEL_DIR)

# NOTE(review): presumably one per-class probability result per input row
# (iterable by default in tf.contrib.learn) — confirm for the TF version used.
predictions = estimator_from_file.predict_proba(input_fn=predict_input_fn)
for prediction in predictions:
    print('Prediction:', prediction)