import os
import re
import dynet as dy
from xnmt import logger
class ParamManager(object):
    """
    A static class that manages the currently loaded DyNet parameters of all components.

    Responsibilities are registering of all components that use DyNet parameters and loading pretrained parameters.
    Components can register parameters by calling ParamManager.my_params(self) from within their __init__() method.
    This allocates a subcollection with a unique identifier for this component. When loading previously saved
    parameters, one or several paths are specified to look for the corresponding saved DyNet collection named after
    this identifier.
    """

    # Flipped to True by init_param_col(); every other entry point asserts on it.
    initialized = False

    @staticmethod
    def init_param_col() -> None:
        """
        Initializes or resets the parameter collection.

        This must be invoked before every time a new model is loaded (e.g. on startup and between consecutive
        experiments). Any previously registered load paths are discarded.
        """
        ParamManager.param_col = ParamCollection()
        ParamManager.load_paths = []
        ParamManager.initialized = True

    @staticmethod
    def add_load_path(data_file: str) -> None:
        """
        Add new data directory path to load from.

        When calling populate(), pretrained parameters from all directories added in this way are searched for the
        requested component identifiers. Adding the same path twice has no effect.

        Args:
          data_file: a data directory (usually named ``*.data``) containing DyNet parameter collections.
        """
        assert ParamManager.initialized, "must call ParamManager.init_param_col() first"
        if data_file not in ParamManager.load_paths:
            ParamManager.load_paths.append(data_file)

    @staticmethod
    def populate() -> None:
        """
        Populate the parameter collections.

        Searches the given data paths and loads parameter collections if they exist, otherwise leave parameters in
        their randomly initialized state. Every load path that contains a file for a subcollection is loaded, in the
        order the paths were added (so presumably the last match determines the final values). A summary of what was
        and was not populated is written to the log.
        """
        assert ParamManager.initialized, "must call ParamManager.init_param_col() first"
        populated_subcols = []
        for subcol_name in ParamManager.param_col.subcols:
            found = False
            for load_path in ParamManager.load_paths:
                data_file = os.path.join(load_path, subcol_name)
                if os.path.isfile(data_file):
                    ParamManager.param_col.load_subcol_from_data_file(subcol_name, data_file)
                    found = True
            # Record each subcollection at most once, even when several load paths provide it; otherwise the
            # length comparison below could claim full population while some components were never loaded.
            if found:
                populated_subcols.append(subcol_name)
        if len(ParamManager.param_col.subcols) == len(populated_subcols):
            logger.info("> populated DyNet weights of all components from given data files")
        elif len(populated_subcols) == 0:
            logger.info("> use randomly initialized DyNet weights of all components")
        else:
            logger.info(f"> populated a subset of DyNet weights from given data files: {populated_subcols}.\n"
                        f" Did not populate {ParamManager.param_col.subcols.keys() - set(populated_subcols)}.\n"
                        f" If partial population was not intended, likely the unpopulated component or its owner"
                        f" does not adhere to the Serializable protocol correctly, see documentation:\n"
                        f" http://xnmt.readthedocs.io/en/latest/writing_xnmt_classes.html#using-serializable-subcomponents")
        logger.info(f" DyNet param count: {ParamManager.param_col._param_col.parameter_count()}")

    @staticmethod
    def my_params(subcol_owner) -> dy.ParameterCollection:
        """Creates a dedicated parameter subcollection for a serializable object.

        This should only be called from the __init__ method of a Serializable.

        Args:
          subcol_owner (Serializable): The object which is requesting to be assigned a subcollection.
        Returns:
          The assigned subcollection.
        Raises:
          ValueError: if ``subcol_owner`` has no ``xnmt_subcol_name`` attribute.
        """
        assert ParamManager.initialized, "must call ParamManager.init_param_col() first"
        assert not getattr(subcol_owner, "init_completed", False), \
            f"my_params(obj) cannot be called after obj.__init__() has completed. Conflicting obj: {subcol_owner}"
        if not hasattr(subcol_owner, "xnmt_subcol_name"):
            raise ValueError(f"{subcol_owner} does not have an attribute 'xnmt_subcol_name'.\n"
                             f"Did you forget to wrap the __init__() in @serializable_init ?")
        subcol_name = subcol_owner.xnmt_subcol_name
        subcol = ParamManager.param_col.add_subcollection(subcol_owner, subcol_name)
        # Hand the name back to the owner as a processed argument (presumably so it is kept when serializing).
        subcol_owner.save_processed_arg("xnmt_subcol_name", subcol_name)
        return subcol

    @staticmethod
    def global_collection() -> dy.ParameterCollection:
        """ Access the top-level parameter collection, including all parameters.

        Returns:
          top-level DyNet parameter collection
        """
        assert ParamManager.initialized, "must call ParamManager.init_param_col() first"
        return ParamManager.param_col._param_col
class ParamCollection(object):
    """
    Holds the top-level DyNet parameter collection plus one named subcollection per registered component.

    Also takes care of writing the subcollections to ``<model_file>.data`` and of keeping up to
    ``save_num_checkpoints`` older copies in ``<model_file>.data.1``, ``<model_file>.data.2``, ...
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Discard all state and start over with a fresh, empty DyNet collection."""
        self._save_num_checkpoints = 1
        self._model_file = None
        # Derived from model_file / save_num_checkpoints; kept empty until both are usable so that
        # save() can fail with a clear message instead of an AttributeError.
        self._data_files = []
        self._param_col = dy.Model()
        self._is_saved = False
        self.subcols = {}
        self.all_subcol_owners = set()

    @property
    def save_num_checkpoints(self):
        """Number of checkpoint directories to keep; setting it recomputes the data directory names."""
        return self._save_num_checkpoints

    @save_num_checkpoints.setter
    def save_num_checkpoints(self, value):
        self._save_num_checkpoints = value
        self._update_data_files()

    @property
    def model_file(self):
        """Base path of the model; setting it recomputes the data directory names."""
        return self._model_file

    @model_file.setter
    def model_file(self, value):
        self._model_file = value
        self._update_data_files()

    def _update_data_files(self):
        """Recompute the list of checkpoint directories, newest first."""
        if self._save_num_checkpoints > 0 and self._model_file:
            newest = self._model_file + '.data'
            older = [newest + '.' + str(i) for i in range(1, self._save_num_checkpoints)]
            self._data_files = [newest] + older
        else:
            self._data_files = []

    def add_subcollection(self, subcol_owner: 'Serializable', subcol_name: str) -> 'dy.ParameterCollection':
        """
        Create and register a new DyNet subcollection for a component.

        Args:
          subcol_owner: the component requesting the subcollection; each owner may register only once.
          subcol_name: unique name under which the subcollection is stored.
        Returns:
          The newly created subcollection.
        Raises:
          RuntimeError: if ``subcol_name`` has already been registered.
        """
        assert subcol_owner not in self.all_subcol_owners
        # Validate the name before touching any state, so a rejected registration leaves nothing behind.
        if subcol_name in self.subcols:
            raise RuntimeError(f'Duplicate subcol_name {subcol_name} found when loading')
        new_subcol = self._param_col.add_subcollection(subcol_name)
        self.all_subcol_owners.add(subcol_owner)
        self.subcols[subcol_name] = new_subcol
        return new_subcol

    def load_subcol_from_data_file(self, subcol_name: str, data_file: str) -> None:
        """Populate the subcollection registered as ``subcol_name`` from ``data_file``."""
        self.subcols[subcol_name].populate(data_file)

    def save(self) -> None:
        """
        Write every subcollection into the newest checkpoint directory.

        On the first save, pre-existing checkpoint directories are cleaned out; on every save the existing
        checkpoints are shifted back by one position before writing.

        Raises:
          RuntimeError: if no checkpoint directory is configured (``model_file`` unset or
            ``save_num_checkpoints`` not positive).
        """
        if not self._data_files:
            raise RuntimeError("cannot save parameters: set model_file (and a positive save_num_checkpoints) first")
        if not self._is_saved:
            self._remove_existing_history()
        self._shift_saved_checkpoints()
        target_dir = self._data_files[0]
        os.makedirs(target_dir, exist_ok=True)
        for subcol_name, subcol in self.subcols.items():
            subcol.save(os.path.join(target_dir, subcol_name))
        self._is_saved = True

    def revert_to_best_model(self) -> None:
        """
        Reload every subcollection from the newest checkpoint directory.

        Raises:
          RevertingUnsavedModelException: if save() has not been called on this collection yet.
        """
        if not self._is_saved:
            raise RevertingUnsavedModelException("revert_to_best_model() is illegal because this model has never been saved.")
        for subcol_name, subcol in self.subcols.items():
            subcol.populate(os.path.join(self._data_files[0], subcol_name))

    def _remove_existing_history(self):
        """Clean out every configured checkpoint directory that currently exists."""
        for fname in self._data_files:
            if os.path.exists(fname):
                self._remove_data_dir(fname)

    def _remove_data_dir(self, data_dir):
        """
        Delete the parameter files inside ``data_dir`` and then the directory itself if nothing else is left.

        Only files named ``<identifier>.<8 lowercase hex digits>`` are deleted; anything else is kept, in which case
        the directory stays in place. If ``data_dir`` turns out to be a plain file, that file is removed.
        """
        parts = data_dir.split(".")
        # Safety net against deleting from an unrelated location: accept only "*.data" or "*.data.<suffix>".
        assert data_dir.endswith(".data") or (len(parts) >= 2 and parts[-2] == "data")
        try:
            dir_contents = os.listdir(data_dir)
        except NotADirectoryError:
            os.remove(data_dir)
            return
        for old_file in dir_contents:
            spl = old_file.split(".")
            # make sure we're only deleting files with the expected filenames
            if (len(spl) == 2
                    and re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", spl[0])
                    and re.match(r"^[0-9a-f]{8}$", spl[1])):
                os.remove(os.path.join(data_dir, old_file))
        # Drop the emptied directory: os.rename() onto an existing directory fails on Windows, which would
        # break the checkpoint rotation in _shift_saved_checkpoints().
        if not os.listdir(data_dir):
            os.rmdir(data_dir)

    def _shift_saved_checkpoints(self):
        """Drop the oldest checkpoint and move every remaining one back by one position."""
        oldest = self._data_files[-1]
        if os.path.exists(oldest):
            self._remove_data_dir(oldest)
        # Walk from second-oldest to newest so each rename target has already been vacated.
        for i in reversed(range(len(self._data_files) - 1)):
            if os.path.exists(self._data_files[i]):
                os.rename(self._data_files[i], self._data_files[i + 1])
class RevertingUnsavedModelException(Exception):
    """Raised by ParamCollection.revert_to_best_model() when the model has never been saved."""