"""
This script contains utility functions that are generally short and easy to understand.
"""
import collections
import collections.abc
import copy
import inspect
import json
import logging
import os
import re
import sys

import git
import numpy as np
import pkg_resources
import tensorflow as tf
from bson.objectid import ObjectId
from tensorflow.python.client import device_lib
def isstring(x):
    """Return True if `x` behaves like a text string.

    Duck-typed check: anything that can be concatenated with a ``str``
    literal counts (str and str subclasses pass; bytes and numbers do not).

    Args:
        x: Object to test.

    Returns:
        bool: True if ``x + ''`` succeeds.
    """
    try:
        x + ''
    except TypeError:
        # Only the expected failure mode is swallowed; the previous bare
        # `except:` would also hide KeyboardInterrupt/SystemExit and real bugs.
        return False
    return True
def version_info(module):
    """Get version of a standard python module.

    Args:
        module (module): python module object to get version info for.

    Returns:
        dict: dictionary of version info.
    """
    # Prefer version metadata exposed directly on the module itself.
    for attr in ('__version__', 'VERSION'):
        if hasattr(module, attr):
            return {'version': getattr(module, attr)}
    # Fall back to the installed distribution of the top-level package.
    pkgname = module.__name__.split('.')[0]
    try:
        dist = pkg_resources.get_distribution(pkgname)
    except (pkg_resources.DistributionNotFound, pkg_resources.RequirementParseError):
        log.warning(
            'version information not found for %s -- what package is this from?' % module.__name__)
        return {'version': None}
    return {'version': dist.version}
def git_info(repo):
    """Return information about a git repo.

    Args:
        repo (git.Repo): The git repo to be investigated.

    Returns:
        dict: Git repo information with keys 'git_dir', 'active_branch',
        'commit', 'remote_urls', 'clean', 'active_branch_in_origin' and
        'commit_in_log'.
    """
    if repo.is_dirty():
        log.warning('repo %s is dirty -- having committment issues?' %
                    repo.git_dir)
        clean = False
    else:
        clean = True
    branchname = repo.active_branch.name
    commit = repo.active_branch.commit.hexsha
    origin = repo.remote('origin')
    # Materialize the urls as a list: the previous `map(str, ...)` left a
    # single-use iterator in the returned dict under Python 3, which is not
    # JSON-serializable and is silently empty after one pass.
    urls = [str(url) for url in origin.urls]
    remote_ref = [_r for _r in origin.refs if _r.name ==
                  'origin/' + branchname]
    if not len(remote_ref) > 0:
        log.warning('Active branch %s not in origin ref' % branchname)
        active_branch_in_origin = False
        commit_in_log = False
    else:
        active_branch_in_origin = True
        remote_ref = remote_ref[0]
        gitlog = remote_ref.log()
        # A reflog entry records both ends of each update, so the commit may
        # appear as either the old or the new sha.
        shas = [_r.oldhexsha for _r in gitlog] + \
            [_r.newhexsha for _r in gitlog]
        if commit not in shas:
            log.warning('Commit %s not in remote origin log for branch %s' % (commit,
                                                                              branchname))
            commit_in_log = False
        else:
            commit_in_log = True
    info = {'git_dir': repo.git_dir,
            'active_branch': branchname,
            'commit': commit,
            'remote_urls': urls,
            'clean': clean,
            'active_branch_in_origin': active_branch_in_origin,
            'commit_in_log': commit_in_log}
    return info
def get_available_gpus():
    """Return the names of all GPU devices visible to TensorFlow."""
    gpu_names = []
    for device in device_lib.list_local_devices():
        if device.device_type == 'GPU':
            gpu_names.append(device.name)
    return gpu_names
def strip_prefix(prefix, all_vars):
    """Map variables to their op names with `prefix` stripped.

    Args:
        prefix (str): Scope prefix to remove from each variable's op name.
        all_vars (iterable): Variables, each exposing an ``op.name`` attribute.

    Returns:
        dict: Mapping from stripped name to variable.
    """
    # The commented-out duplicate of strip_prefix_from_name that used to live
    # here has been removed; the real helper is defined below in this module.
    return {strip_prefix_from_name(prefix, var.op.name): var
            for var in all_vars}
def strip_prefix_from_name(prefix, name):
    """Repeatedly strip a leading ``prefix/`` scope component from `name`.

    Args:
        prefix (str): Scope prefix, with or without a trailing slash.
        name (str): Name to strip.

    Returns:
        str: `name` with every leading occurrence of the prefix removed.
    """
    scope = prefix if prefix.endswith('/') else prefix + '/'
    # Iterative equivalent of the original tail recursion.
    while name.startswith(scope):
        name = name[len(scope):]
    return name
def aggregate_outputs(tower_outputs):
    """Return aggregated model replicate outputs.

    The elements of `tower_outputs` should have identical structure and
    correspond to the outputs of individual model replicas on separate
    devices (GPUs). Model replicate outputs are recursively searched until
    a tensor ``t`` satisfying:

        isinstance(t, tf.Tensor) -> True

    is found. Tensor ``t`` is then concatenated with all of its corresponding
    replicates along the batch dimension (axis=0).

    If ``tower_outputs`` is a list of length one, then the element it contains
    is returned.

    Args:
        tower_outputs (list): The outputs of individual model replicas.

    Returns:
        The aggregated output with a structure identical to the replicate outputs.

    Raises:
        TypeError: Aggregation not supported for given type.

    Examples:
        >>> print(tower_outputs)
        [{'tensor': <tf.Tensor 'softmax_linear/fc/output:0' shape=(50, 10) dtype=float32>},
         {'tensor': <tf.Tensor 'softmax_linear_1/fc/output:0' shape=(50, 10) dtype=float32>}]
        >>>
        >>> print(aggregate_outputs(tower_outputs))
        {'tensor': <tf.Tensor 'concat:0' shape=(100, 10) dtype=float32>}
    """
    if len(tower_outputs) == 1:
        return tower_outputs.pop()
    # Tensorflow tensors are concatenated along axis 0.
    elif isinstance(tower_outputs[0], tf.Tensor):
        return tf.concat(tower_outputs, axis=0)
    # Tensorflow variables are not processed.
    elif isinstance(tower_outputs[0], tf.Variable):
        return tower_outputs
    # Dict values are aggregated by key. NOTE: the ABC aliases on the
    # top-level `collections` module were removed in Python 3.10, so the
    # classes must come from `collections.abc`.
    elif isinstance(tower_outputs[0], collections.abc.Mapping):
        return {key: aggregate_outputs([out[key] for out in tower_outputs])
                for key in tower_outputs[0]}
    # List elements are aggregated by index.
    elif isinstance(tower_outputs[0], collections.abc.Iterable):
        return [aggregate_outputs(out) for out in zip(*tower_outputs)]
    # All other types are not supported.
    raise TypeError('Aggregation not supported for type: {}'.
                    format(type(tower_outputs[0])))
def identity_func(x):
    """Return `x` unchanged if it is dict-like, else wrap it as {'result': x}."""
    return x if hasattr(x, 'keys') else {'result': x}
def append_and_return(x, y, step):
    """Append `y` to accumulator list `x` (created when None) and return it.

    `step` is accepted for signature compatibility with the other
    aggregation helpers in this module and is unused.
    """
    accumulator = [] if x is None else x
    accumulator.append(y)
    return accumulator
def reduce_mean(x, y, step):
    """Fold sample `y` into running mean `x` at zero-based `step`.

    On the first call (`x` is None) the mean is just `y`; afterwards the
    running mean is re-weighted by step/(step+1).
    """
    if x is None:
        return y
    weight = step / (step + 1.)
    return weight * x + (1 - weight) * y
def reduce_mean_dict(x, y, step):
    """Fold metric dict `y` into running aggregate `x` at `step`.

    Regular keys are reduced with an online mean (see `reduce_mean`); the
    special key 'validation_step' is tracked as a [min, max] range instead.

    Args:
        x (dict or None): Running aggregate; None on the first step.
        y (dict): New metric values keyed by name.
        step (int): Zero-based aggregation step.

    Returns:
        dict: The updated aggregate `x`.
    """
    if x is None:
        x = {}
    for k in y:
        if k != 'validation_step':
            x[k] = reduce_mean(x.get(k), y[k], step)
        else:
            prev = x.get(k)
            if prev is None:
                # First observation: the range collapses to a single point.
                # (The original read x[k] before it was ever assigned,
                # raising KeyError on the first 'validation_step' seen.)
                x[k] = [y[k], y[k]]
            else:
                # prev is [lo, hi]; equivalent to the original's
                # min(min(prev), v) / max(max(prev), v) on the working path.
                x[k] = [min(prev[0], y[k]), max(prev[1], y[k])]
    return x
def mean_dict(y):
    """Aggregate a list of metric dicts into one dict.

    Each key is reduced to the mean of its values across the list; the
    special key 'validation_step' becomes a [min, max] range instead.

    Args:
        y (list): Non-empty list of dicts sharing the keys of ``y[0]``.

    Returns:
        dict: Aggregated values keyed like ``y[0]``.
    """
    aggregated = {}
    for key in y[0]:
        values = [entry[key] for entry in y]
        if key == 'validation_step':
            aggregated[key] = [min(values), max(values)]
        else:
            aggregated[key] = np.mean(values)
    return aggregated
class frozendict(collections.abc.Mapping):
    """An immutable dictionary.

    An immutable wrapper around dictionaries that implements the complete
    :py:class:`collections.abc.Mapping` interface. It can be used as a
    drop-in replacement for dictionaries where immutability is desired.
    (The ``collections.Mapping`` alias this previously subclassed was
    removed in Python 3.10.)

    from https://pypi.python.org/pypi/frozendict
    """

    # Subclasses may override to wrap an alternative mapping type.
    dict_cls = dict

    def __init__(self, *args, **kwargs):
        self._dict = self.dict_cls(*args, **kwargs)
        self._hash = None  # computed lazily on first hash()

    def __getitem__(self, key):
        return self._dict[key]

    def __contains__(self, key):
        return key in self._dict

    def copy(self, **add_or_replace):
        """Return a new frozendict with `add_or_replace` entries merged in."""
        return self.__class__(self, **add_or_replace)

    def __iter__(self):
        return iter(self._dict)

    def __len__(self):
        return len(self._dict)

    def __repr__(self):
        return '<%s %r>' % (self.__class__.__name__, self._dict)

    def __hash__(self):
        # XOR of per-item hashes: order-independent and cached after the
        # first call; requires all values to be hashable.
        if self._hash is None:
            h = 0
            for key, value in self._dict.items():
                h ^= hash((key, value))
            self._hash = h
        return self._hash
def predict(step, results):
    """Return argmax class predictions for each model output.

    Args:
        step: Unused; kept for signature compatibility.
        results (dict): Must contain 'outputs' -- a single logits tensor or
            an iterable of logits tensors.

    Returns:
        list: One ``tf.argmax`` prediction tensor per output.
    """
    outputs = results['outputs']
    # NOTE(review): the original tested results['output'] (singular) while
    # reading results['outputs'], raising KeyError whenever only 'outputs'
    # was present -- assumed to be a typo and fixed to use one key.
    if not hasattr(outputs, '__iter__'):
        outputs = [outputs]
    preds = [tf.argmax(output, 1) for output in outputs]
    return preds
def online_agg(agg_res, res, step):
    """Append the mean of each value in `res` to its per-key history list.

    Useful for averaging items within one dictionary across steps.
    `step` is accepted for signature compatibility and unused.

    Args:
        agg_res (dict or None): Running per-key lists; None on the first step.
        res (dict): New values (scalars or arrays) keyed by name.

    Returns:
        dict: `agg_res` with one mean appended per key.
    """
    if agg_res is None:
        agg_res = {key: [] for key in res}
    for key, value in res.items():
        agg_res[key].append(np.mean(value))
    return agg_res