"""Test base module."""
import os
import re
import sys
import errno
import shutil
import cPickle
import logging
import unittest
from collections import defaultdict
import copy
import gridfs
import pymongo
import tensorflow as tf
import pdb
import mnist_data as data
sys.path.insert(0, "..")
import tfutils.base as base
import tfutils.model as model
import tfutils.utils as utils
from tfutils.db_interface import TFUTILS_HOME
# Shared training configuration: MNIST batches of 100 drawn from the
# 'train' split, trained for 500 steps.  Deep-copied by each test class's
# setup_params so per-test mutations do not leak between tests.
TRAIN_PARAMS = {
    'data_params': {'func': data.build_data,
                    'batch_size': 100,
                    'group': 'train',
                    'directory': TFUTILS_HOME},
    'num_steps': 500}

# Shared validation configuration: a single 'valid0' target that runs
# 10 steps on the 'test' split and aggregates results with mean_dict.
VALIDATION_PARAMS = {
    'valid0': {
        'data_params': {
            'func': data.build_data,
            'batch_size': 100,
            'group': 'test',
            'directory': TFUTILS_HOME},
        'num_steps': 10,
        'agg_func': utils.mean_dict}}

# presumably 10000 examples per epoch at batch size 256 — TODO confirm
NUM_BATCHES_PER_EPOCH = 10000 // 256

# Exponentially decayed learning rate: decays by 0.95 once per epoch,
# in discrete (staircase) steps.
LEARNING_RATE_PARAMS = {
    'learning_rate': 0.05,
    'decay_steps': NUM_BATCHES_PER_EPOCH,
    'decay_rate': 0.95,
    'staircase': True}

# Model-construction function used by every test class below.
MODEL_BUILD_FUNC = model.mnist_tfutils_new
def setUpModule():
    """Set up module once, before any TestCases are run.

    Installs a basic logging handler so the per-class loggers created in
    setUpClass have somewhere to emit.
    """
    logging.basicConfig()
def tearDownModule():
    """Tear down module after all TestCases are run.

    Intentionally a no-op: per-class teardown already drops the test
    databases and closes the MongoDB connection.
    """
    pass
class TestBase(unittest.TestCase):
    """Integration tests for tfutils.base training/validation.

    Requires a MongoDB server listening on ``host:port``; results are
    written to the ``database_name`` database under a collection named
    after the test class, and removed again in tearDown/tearDownClass.
    """

    # MongoDB connection settings shared by all subclasses.
    port = 29101
    host = 'localhost'
    database_name = '_tfutils'

    @classmethod
    def setUpClass(cls):
        """Set up class once before any test methods are run."""
        cls.log = logging.getLogger(':'.join([__name__, cls.__name__]))
        cls.log.setLevel('DEBUG')

        # Open primary MongoDB connection.
        cls.conn = pymongo.MongoClient(host=cls.host,
                                       port=cls.port)

        # Each test class gets its own collection, named after the class.
        cls.collection_name = cls.__name__
        cls.collection = cls.conn[cls.database_name][cls.collection_name]

    @classmethod
    def tearDownClass(cls):
        """Tear down class after all test methods have run."""
        cls.remove_database(cls.database_name)

        # Also drop any derived databases (e.g. caches) whose names are
        # prefixed with this test's database name.
        for name in cls.conn.database_names():
            if name.startswith(cls.database_name):
                cls.conn.drop_database(name)

        # TODO: Remove cache, if desired.

        # Close primary MongoDB connection.
        cls.conn.close()

    def setUp(self):
        """Set up class before each test method is executed."""
        pass

    def tearDown(self):
        """Tear Down is called after each test method is executed."""
        self.remove_collection(self.collection_name)

    def setup_params(self, exp_id):
        """Return a fresh params dict for training with the given exp_id.

        TRAIN/LEARNING_RATE/VALIDATION params are deep-copied so tests can
        mutate the returned dict freely.
        """
        params = {}

        params['model_params'] = {
            'func': MODEL_BUILD_FUNC,
            # 'devices': ['/gpu:0', '/gpu:1'],
        }

        params['save_params'] = {
            'host': self.host,
            'port': self.port,
            'dbname': self.database_name,
            'collname': self.collection_name,
            'exp_id': exp_id,
            'save_valid_freq': 20,
            'save_filters_freq': 200,
            'cache_filters_freq': 100}

        params['train_params'] = copy.deepcopy(TRAIN_PARAMS)
        params['learning_rate_params'] = copy.deepcopy(LEARNING_RATE_PARAMS)
        params['validation_params'] = copy.deepcopy(VALIDATION_PARAMS)
        params['skip_check'] = True

        return params

    def test_training(self):
        """Illustrate training.

        This test illustrates how basic training is performed using the
        tfutils.base.train_from_params function.  This is the first in a
        sequence of interconnected tests. It creates a pretrained model
        that is used by the next few tests (test_validation and
        test_feature_extraction).

        As can be seen by looking at how the test checks for correctness,
        after the training is run, results of training, including
        (intermittently) the full variables needed to re-initialize the
        tensorflow model, are stored in a MongoDB.

        Also see docstring of the tfutils.base.train_from_params function
        for more detailed information about usage.
        """
        exp_id = 'training0'
        params = self.setup_params(exp_id)

        # Run training.
        base.train_from_params(**params)

        # Test if results are as expected.
        self.assert_as_expected(exp_id, count=26, step=[0, 200, 400])
        r = self.collection['files'].find({'exp_id': exp_id, 'step': 0})[0]
        self.asserts_for_record(r, params, train=True)
        r = self.collection['files'].find({'exp_id': exp_id, 'step': 20})[0]
        self.asserts_for_record(r, params, train=True)

        # Run another 500 steps of training on the same experiment id.
        params['train_params']['num_steps'] = 1000
        base.train_from_params(**params)

        # Test if results are as expected.
        self.assert_as_expected(exp_id, 51, [0, 200, 400, 600, 800, 1000])
        self.assertEqual(self.collection['files'].distinct('exp_id'),
                         [exp_id])
        r = self.collection['files'].find({'exp_id': exp_id, 'step': 1000})[0]
        self.asserts_for_record(r, params, train=True)

        # Run 500 more steps but save to a new experiment id.
        new_exp_id = 'training1'
        params['train_params']['num_steps'] = 1500
        params['load_params'] = {'exp_id': exp_id}
        params['save_params']['exp_id'] = new_exp_id
        base.train_from_params(**params)
        self.assert_step(new_exp_id, [1200, 1400])

    def test_custom_training(self):
        """Illustrate training with custom training loop.

        This test illustrates how basic training is performed with a
        custom training loop using the tfutils.base.train_from_params
        function.
        """
        exp_id = 'training0'
        params = self.setup_params(exp_id)

        # Add a custom train_loop to use during training.
        params['train_params']['train_loop'] = {'func': self.custom_train_loop}

        base.train_from_params(**params)

    def test_training_save(self):
        """Illustrate saving to the grid file system during training time."""
        exp_id = 'training_save'
        params = self.setup_params(exp_id)

        # Modify a few of the save parameters.
        params['save_params']['save_valid_freq'] = 3000
        params['save_params']['save_filters_freq'] = 30000
        params['save_params']['cache_filters_freq'] = 3000

        # Specify additional save_params for saving to gfs.
        params['save_params']['save_to_gfs'] = ['first_image']
        params['train_params']['targets'] = {
            'func': self.get_first_image_target}

        # Actually run the training.
        base.train_from_params(**params)

        # Check that the first image has been saved.
        coll = self.collection['files']
        q = {'exp_id': exp_id, 'train_results': {'$exists': True}}
        train_steps = coll.find(q)
        self.assertEqual(train_steps.count(), 5)
        idx = train_steps[0]['_id']
        fn = coll.find({'item_for': idx})[0]['filename']
        fs = gridfs.GridFS(coll.database, self.collection_name)
        fh = fs.get_last_version(fn)
        saved_data = cPickle.loads(fh.read())
        fh.close()

        # Assert as expected.
        self.assertIn('train_results', saved_data)
        self.assertIn('first_image', saved_data['train_results'])
        self.assertEqual(len(saved_data['train_results']['first_image']), 100)
        self.assertEqual(
            saved_data['train_results']['first_image'][0].shape, (28 * 28,))

    def test_validation(self):
        """Illustrate validation.

        This is a test illustrating how to compute performance on a
        trained model on a new dataset, using the
        tfutils.base.test_from_params function. This test assumes that
        test_training function has run first (to provide a pre-trained
        model to validate).

        After the test is run, results from the validation are stored in
        the MongoDB.  (The test shows how the record can be loaded for
        inspection.)

        See the docstring of tfutils.base.test_from_params for more
        detailed information on usage.
        """
        # Specify the parameters for the validation.
        exp_id = 'training0'
        val_exp_id = 'validation0'
        params = self.setup_params(exp_id)
        params.pop('train_params')
        params.pop('learning_rate_params')
        params['load_params'] = params['save_params']
        params['save_params'] = {'exp_id': val_exp_id}

        # Actually run the model.
        base.test_from_params(**params)

        # ... specifically, there is now a record containing the
        # validation0 performance results
        self.assertEqual(
            self.collection['files'].find({'exp_id': val_exp_id}).count(), 1)

        # ... here's how to load the record:
        r = self.collection['files'].find({'exp_id': val_exp_id})[0]
        self.asserts_for_record(r, params, train=False)

        # ... check that the record correctly ties to the id information
        # for the pre-trained model it was supposed to validate
        self.assertTrue(r['validates'])
        idval = self.collection['files'].find({'exp_id': exp_id})[50]['_id']
        v = self.collection['files'].find(
            {'exp_id': val_exp_id})[0]['validates']
        self.assertEqual(idval, v)

    def assert_count(self, exp_id, count):
        """Assert the number of records stored for exp_id."""
        self.assertEqual(
            self.collection['files'].find({'exp_id': exp_id}).count(),
            count)

    def assert_step(self, exp_id, step):
        """Assert the steps at which filters were saved for exp_id."""
        self.assertEqual(
            self.collection['files']
            .find({'exp_id': exp_id, 'saved_filters': True})
            .distinct('step'),
            step)

    def assert_as_expected(self, exp_id, count=None, step=None):
        """Run assert_count and/or assert_step for whichever was given."""
        if count is not None:
            self.assert_count(exp_id, count)
        if step is not None:
            self.assert_step(exp_id, step)

    @classmethod
    def remove_directory(cls, directory):
        """Remove a directory."""
        cls.log.info('Removing directory: {}'.format(directory))
        shutil.rmtree(directory)
        cls.log.info('Directory successfully removed.')

    @classmethod
    def remove_database(cls, database_name):
        """Remove a MongoDB database."""
        cls.log.info('Removing database: {}'.format(database_name))
        cls.conn.drop_database(database_name)
        cls.log.info('Database successfully removed.')

    @classmethod
    def remove_collection(cls, collection_name):
        """Remove a MongoDB collection."""
        cls.log.info('Removing collection: {}'.format(collection_name))
        cls.conn[cls.database_name][collection_name].drop()
        cls.log.info('Collection successfully removed.')

    @classmethod
    def remove_document(cls, document):
        """Remove a single MongoDB document (not yet implemented)."""
        raise NotImplementedError

    @staticmethod
    def makedirs(dir):
        """Create dir (and parents), ignoring an already-existing dir."""
        try:
            os.makedirs(dir)
        except OSError as e:
            # Only swallow "already exists"; re-raise anything else.
            if e.errno != errno.EEXIST:
                raise

    @staticmethod
    def get_first_image_target(inputs, outputs, **ttarg_params):
        """Return target for saving the first image of every batch.

        Used in test_training_save test to test save_to_gfs option.
        """
        return {'first_image': inputs['images'][0]}

    @staticmethod
    def custom_train_loop(sess, train_targets, **loop_params):
        """Define Custom training loop.

        Args:
            sess (tf.Session): Current tensorflow session.
            train_targets (list): Description.
            **loop_params: Optional kwargs needed to perform custom train
                loop.

        Returns:
            dict: A dictionary containing train targets evaluated by the
                session.
        """
        train_results = sess.run(train_targets)
        for i, result in enumerate(train_results):
            print('Model {} has loss {}'.format(i, result['loss']))
        return train_results

    @staticmethod
    def asserts_for_record(r, params, train=False):
        """Assert a saved DB record r matches the params it was run with."""
        if r.get('saved_filters'):
            assert r['_saver_write_version'] == 2
            assert r['_saver_num_data_files'] == 1
        assert isinstance(r['duration'], float)

        should_contain = ['save_params', 'load_params',
                          'model_params', 'validation_params']
        assert set(should_contain).difference(r['params'].keys()) == set()

        # Materialize keys as lists so indexing below also works on
        # Python 3, where dict.keys() is a view.
        vk = list(r['params']['validation_params'].keys())
        vk1 = list(r['validation_results'].keys())
        assert set(vk) == set(vk1)

        assert r['params']['model_params']['seed'] == 0
        assert r['params']['model_params']['func']['modname'] \
            == 'tfutils.model_tool'
        assert r['params']['model_params']['func']['objname'] \
            == 'mnist_tfutils'

        _k = vk[0]
        should_contain = ['agg_func', 'data_params', 'num_steps',
                          'online_agg_func', 'targets']
        assert set(should_contain).difference(
            r['params']['validation_params'][_k].keys()) == set()

        if train:
            assert r['params']['model_params']['train'] is True
            for k in ['num_steps']:
                assert r['params']['train_params'][k] \
                    == params['train_params'][k]

            should_contain = ['loss_params', 'optimizer_params',
                              'train_params', 'learning_rate_params']
            assert set(should_contain).difference(r['params'].keys()) == set()

            r_train_params = r['params']['train_params']
            assert r_train_params['thres_loss'] == 100
            assert r_train_params['data_params']['func']['modname'] \
                == 'mnist_data', \
                r_train_params['data_params']['func']['modname']
            assert r_train_params['data_params']['func']['objname'] \
                == 'build_data',\
                r_train_params['data_params']['func']['objname']

            assert r['params']['loss_params']['agg_func']['modname'] \
                == 'tensorflow.python.ops.math_ops'
            assert r['params']['loss_params']['agg_func']['objname'] \
                == 'reduce_mean'
            assert r['params']['loss_params']['loss_func']['modname'] \
                == 'tensorflow.python.ops.nn_ops'
            assert r['params']['loss_params']['loss_func']['objname'] \
                == 'sparse_softmax_cross_entropy_with_logits'
            assert r['params']['loss_params']['pred_targets'] == ['labels']
        else:
            assert not r['params']['model_params']['train']
            assert 'train_params' not in r['params']
class TestDistributedModel(TestBase):
    """Repeat the TestBase tests with the model split across two GPUs."""

    def setup_params(self, exp_id):
        """Return the base params with the model placed on two devices.

        Produces exactly the same params dict as TestBase.setup_params,
        plus a 'devices' entry in model_params.
        """
        params = super(TestDistributedModel, self).setup_params(exp_id)
        params['model_params']['devices'] = ['/gpu:0', '/gpu:1']
        return params
class TestMultiModel(TestBase):
    """Repeat the TestBase tests training two models simultaneously.

    Each model's results are saved under '<exp_id>_model_<i>'.
    """

    def setup_params(self, exp_id):
        """Return the base params with a list of two model_params dicts."""
        params = super(TestMultiModel, self).setup_params(exp_id)
        params['model_params'] = [
            {'func': MODEL_BUILD_FUNC},
            {'func': MODEL_BUILD_FUNC}]
        return params

    def test_training(self):
        """Train two models at once and check each model's records."""
        base_exp_id = 'training0'
        params = self.setup_params(base_exp_id)
        num_models = len(params['model_params'])

        # Actually run the training.
        base.train_from_params(**params)

        # Test if results are as expected.
        for i in range(num_models):
            exp_id = base_exp_id + '_model_{}'.format(i)
            self.assert_as_expected(exp_id, count=26, step=[0, 200, 400])
            r = self.collection['files'].find(
                {'exp_id': exp_id, 'step': 0})[0]
            self.asserts_for_record(r, params, train=True)
            r = self.collection['files'].find(
                {'exp_id': exp_id, 'step': 20})[0]
            self.asserts_for_record(r, params, train=True)

        # Run another 500 steps of training on the same experiment id.
        params['train_params']['num_steps'] = 1000
        base.train_from_params(**params)

        # Test if results are as expected.
        for i in range(num_models):
            exp_id = base_exp_id + '_model_{}'.format(i)
            self.assert_as_expected(exp_id, 51, [0, 200, 400, 600, 800, 1000])
            # NOTE: assertItemsEqual is the Python 2 name of
            # assertCountEqual.
            self.assertItemsEqual(
                self.collection['files'].distinct('exp_id'),
                [base_exp_id + '_model_{}'.format(i)
                 for i in range(num_models)])
            r = self.collection['files'].find(
                {'exp_id': exp_id, 'step': 1000})[0]
            self.asserts_for_record(r, params, train=True)

        # Run 500 more steps but save to a new experiment id.
        new_exp_id = 'training1'
        params['train_params']['num_steps'] = 1500
        params['load_params'] = {'exp_id': base_exp_id}
        params['save_params']['exp_id'] = new_exp_id
        base.train_from_params(**params)
        for i in range(num_models):
            exp_id = new_exp_id + '_model_{}'.format(i)
            self.assert_step(exp_id, [1200, 1400])

    def test_training_save(self):
        """Illustrate saving to the grid file system during training time."""
        base_exp_id = 'training_save'
        params = self.setup_params(base_exp_id)
        num_models = len(params['model_params'])

        params['save_params']['save_to_gfs'] = ['first_image']
        params['save_params']['save_valid_freq'] = 3000
        params['save_params']['save_filters_freq'] = 30000
        params['save_params']['cache_filters_freq'] = 3000
        params['train_params']['targets'] = {
            'func': self.get_first_image_target}

        # Actually run the training.
        base.train_from_params(**params)

        # Check that the first image has been saved, for each model.
        for i in range(num_models):
            exp_id = base_exp_id + '_model_{}'.format(i)
            coll = self.collection['files']
            q = {'exp_id': exp_id, 'train_results': {'$exists': True}}
            train_steps = coll.find(q)
            self.assertEqual(train_steps.count(), 5)
            idx = train_steps[0]['_id']
            fn = coll.find({'item_for': idx})[0]['filename']
            fs = gridfs.GridFS(coll.database, self.collection_name)
            fh = fs.get_last_version(fn)
            saved_data = cPickle.loads(fh.read())
            fh.close()

            self.assertIn('train_results', saved_data)
            self.assertIn('first_image', saved_data['train_results'])
            self.assertEqual(
                len(saved_data['train_results']['first_image']), 100)
            self.assertEqual(
                saved_data['train_results']['first_image'][0].shape,
                (28 * 28,))

    def test_validation(self):
        """Validate both pre-trained models and check their records."""
        # Specify the parameters for the validation.
        base_exp_id = 'training0'
        base_val_exp_id = 'validation0'
        params = self.setup_params(base_exp_id)
        num_models = len(params['model_params'])
        params.pop('train_params')
        params.pop('learning_rate_params')
        params['load_params'] = params['save_params']
        params['save_params'] = {'exp_id': base_val_exp_id}

        # Actually run the model
        base.test_from_params(**params)

        # Check that the results are correct.
        for i in range(num_models):
            exp_id = base_exp_id + '_model_{}'.format(i)
            val_exp_id = base_val_exp_id + '_model_{}'.format(i)

            # ... specifically, there is now a record containing the
            # validation0 performance results
            self.assertEqual(
                self.collection['files'].find(
                    {'exp_id': val_exp_id}).count(), 1)

            # ... here's how to load the record:
            r = self.collection['files'].find({'exp_id': val_exp_id})[0]
            self.asserts_for_record(r, params, train=False)

            # ... check that the record correctly ties to the id
            # information for the pre-trained model it was supposed to
            # validate
            self.assertTrue(r['validates'])
            idval = self.collection['files'].find(
                {'exp_id': exp_id})[50]['_id']
            v = self.collection['files'].find(
                {'exp_id': val_exp_id})[0]['validates']
            self.assertEqual(idval, v)
class TestDistributedMulti(TestMultiModel):
    """Repeat the multi-model tests with each model split across GPUs."""

    def setup_params(self, exp_id):
        """Return the base params with two models on two devices each.

        Produces exactly the same params dict as TestMultiModel's
        setup_params, with a 'devices' entry added per model.
        """
        params = super(TestDistributedMulti, self).setup_params(exp_id)
        params['model_params'] = [
            {'func': MODEL_BUILD_FUNC,
             'devices': ['/gpu:0', '/gpu:1']},
            {'func': MODEL_BUILD_FUNC,
             'devices': ['/gpu:2', '/gpu:3']}]
        return params
if __name__ == '__main__':
    # Run all TestCases in this module when executed as a script.
    unittest.main()