from tfutils.db_interface import DBInterface
from tfutils.helper import parse_params, log
from tfutils.validation import run_all_validations, get_valid_targets_dict
import tensorflow as tf
from tfutils.utils import strip_prefix
from tensorflow.python.ops import variables
import time
from tfutils.defaults import DEFAULT_HOST
[docs]def test(sess,
dbinterface,
validation_targets,
save_intermediate_freq=None):
"""
Actually runs the testing evaluation loop.
Args:
sess (tensorflow.Session): Object in which to run calculations
dbinterface (DBInterface object): Saver through which to save results
validation_targets (dict of tensorflow objects): Objects on which validation will be computed.
save_intermediate_freq (None or int): How frequently to save intermediate results captured during test
None means no intermediate saving will be saved
Returns:
dict: Validation summary.
dict: Results.
"""
# Collect args in a dict of lists
test_args = {
'dbinterface': dbinterface,
'validation_targets': validation_targets,
'save_intermediate_freq': save_intermediate_freq}
_ttargs = [{key: value[i] for (key, value) in test_args.items()}
for i in range(len(dbinterface))]
for ttarg in _ttargs:
ttarg['dbinterface'].start_time_step = time.time()
validation_summary = run_all_validations(
sess,
ttarg['validation_targets'],
save_intermediate_freq=ttarg['save_intermediate_freq'],
dbinterface=ttarg['dbinterface'],
validation_only=True)
res = []
for ttarg in _ttargs:
ttarg['dbinterface'].sync_with_host()
res.append(ttarg['dbinterface'].outrecs)
return validation_summary, res
[docs]def test_from_params(load_params,
model_params,
validation_params,
log_device_placement=False,
save_params=None,
dont_run=False,
skip_check=False,
):
"""
Main testing interface function.
Same as train_from_parameters; but just performs testing without training.
For documentation, see argument descriptions in train_from_params.
"""
params, test_args = parse_params(
'test',
model_params,
dont_run=dont_run,
skip_check=skip_check,
save_params=save_params,
load_params=load_params,
validation_params=validation_params,
log_device_placement=log_device_placement,
)
with tf.Graph().as_default(), tf.device(DEFAULT_HOST):
# create session
sess = tf.Session(
config=tf.ConfigProto(
allow_soft_placement=True,
log_device_placement=log_device_placement,
))
init_op_global = tf.global_variables_initializer()
sess.run(init_op_global)
init_op_local = tf.local_variables_initializer()
sess.run(init_op_local)
log.info('Initialized from scratch first')
# For convenience, use list of dicts instead of dict of lists
_params = [{key: value[i] for (key, value) in params.items()}
for i in range(len(params['model_params']))]
_ttargs = [{key: value[i] for (key, value) in test_args.items()}
for i in range(len(params['model_params']))]
# Build a graph for each distinct model.
for param, ttarg in zip(_params, _ttargs):
if not 'cache_dir' in load_params:
temp_cache_dir = save_params.get('cache_dir', None)
load_params['cache_dir'] = temp_cache_dir
log.info('cache_dir not found in load_params, using cache_dir ({}) from save_params'.format(temp_cache_dir))
ttarg['dbinterface'] = DBInterface(params=param, load_params=param['load_params'])
ttarg['dbinterface'].load_rec()
ld = ttarg['dbinterface'].load_data
assert ld is not None, "No load data found for query, aborting"
ld = ld[0]
# TODO: have option to reconstitute model_params entirely from
# saved object ("revivification")
param['model_params']['seed'] = ld['params']['model_params']['seed']
cfg_final = ld['params']['model_params']['cfg_final']
ttarg['validation_targets'] = \
get_valid_targets_dict(
loss_params=None,
cfg_final=cfg_final,
**param)
# tf.get_variable_scope().reuse_variables()
param['load_params']['do_restore'] = True
param['model_params']['cfg_final'] = cfg_final
prefix = param['model_params']['prefix'] + '/'
all_vars = variables._all_saveable_objects()
var_list = strip_prefix(prefix, all_vars)
ttarg['dbinterface'] = DBInterface(sess=sess,
params=param,
var_list=var_list,
load_params=param['load_params'],
save_params=param['save_params'])
ttarg['dbinterface'].initialize(no_scratch=True)
ttarg['save_intermediate_freq'] = param['save_params'].get('save_intermediate_freq')
# Convert back to a dictionary of lists
params = {key: [param[key] for param in _params]
for key in _params[0].keys()}
test_args = {key: [ttarg[key] for ttarg in _ttargs]
for key in _ttargs[0].keys()}
if dont_run:
return test_args
res = test(sess, **test_args)
sess.close()
return res