Source code for xnmt.persistence

"""
This module takes care of loading and saving YAML files. Both configuration files and saved models are stored in the
same YAML file format.

The main objects to be aware of are:

* :class:`Serializable`: must be subclassed by all components that are specified in a YAML file.
* :class:`Ref`: a reference that points somewhere in the object hierarchy, for both convenience and to realize parameter sharing.
* :class:`Repeat`: a syntax for creating a list components with same configuration but without parameter sharing.
* :class:`YamlPreloader`: pre-loads YAML contents so that some infrastructure can be set up, but does not initialize components.
* :meth:`initialize_if_needed`, :meth:`initialize_object`: initialize a preloaded YAML tree, taking care of resolving references etc.
* :meth:`save_to_file`: saves a YAML file along with registered DyNet parameters
* :class:`LoadSerialized`: can be used to load, modify, and re-assemble pretrained models.
* :meth:`bare`: create uninitialized objects, usually for the purpose of specifying them as default arguments.
* :class:`RandomParam`: a special Serializable subclass that realizes random parameter search.

"""

from functools import singledispatch
from enum import IntEnum, auto
import collections.abc
import numbers
import logging
logger = logging.getLogger('xnmt')
import os
import copy
from functools import lru_cache, wraps
from collections import OrderedDict
import collections.abc
from typing import List, Set, Callable, TypeVar, Type, Union, Optional, Dict, Any
import inspect, random

import yaml

from xnmt import param_collections, tee, utils
import xnmt

def serializable_init(f):
  @wraps(f)
  def wrapper(obj, *args, **kwargs):
    if "xnmt_subcol_name" in kwargs:
      xnmt_subcol_name = kwargs.pop("xnmt_subcol_name")
    elif hasattr(obj, "xnmt_subcol_name"): # happens when calling wrapped super() constructors
      xnmt_subcol_name = obj.xnmt_subcol_name
    else:
      xnmt_subcol_name = _generate_subcol_name(obj)
    obj.xnmt_subcol_name = xnmt_subcol_name
    serialize_params = dict(kwargs)
    params = inspect.signature(f).parameters
    if len(args) > 0:
      param_names = [p.name for p in list(params.values())]
      assert param_names[0] == "self"
      param_names = param_names[1:]
      for i, arg in enumerate(args):
        serialize_params[param_names[i]] = arg
    auto_added_defaults = set()
    for param in params.values():
      if param.name != "self" and param.default != inspect.Parameter.empty and param.name not in serialize_params:
        serialize_params[param.name] = copy.deepcopy(param.default)
        auto_added_defaults.add(param.name)
    for arg in serialize_params.values():
      if type(obj).__name__ != "Experiment":
        assert type(arg).__name__ != "ExpGlobal", \
          "ExpGlobal can no longer be passed directly. Use a reference to its properties instead."
        assert type(arg).__name__ != "ParameterCollection", \
          "cannot pass dy.ParameterCollection to a Serializable class. " \
          "Use ParamManager.my_params() from within the Serializable class's __init__() method instead."
    for key, arg in list(serialize_params.items()):
      if isinstance(arg, Ref):
        if not arg.is_required():
          serialize_params[key] = copy.deepcopy(arg.get_default())
        else:
          if key in auto_added_defaults:
            raise ValueError(
              f"Required argument '{key}' of {type(obj).__name__}.__init__() was not specified, and {arg} could not be resolved")
          else:
            raise ValueError(
              f"Cannot pass a reference as argument; received {serialize_params[key]} in {type(obj).__name__}.__init__()")
    for key, arg in list(serialize_params.items()):
      if getattr(arg, "_is_bare", False):
        initialized = initialize_object(UninitializedYamlObject(arg))
        assert not getattr(initialized, "_is_bare", False)
        serialize_params[key] = initialized
    f(obj, **serialize_params)
    if param_collections.ParamManager.initialized and xnmt_subcol_name in param_collections.ParamManager.param_col.subcols:
      serialize_params["xnmt_subcol_name"] = xnmt_subcol_name
    serialize_params.update(getattr(obj, "serialize_params", {}))
    if "yaml_path" in serialize_params: del serialize_params["yaml_path"]
    obj.serialize_params = serialize_params
    obj.init_completed = True
    # TODO: the below is needed for proper reference creation when saving the model, but should be replaced with
    # something safer
    for key, arg in serialize_params.items():
      if not hasattr(obj, key):
        setattr(obj, key, arg)

  wrapper.uses_serializable_init = True
  return wrapper

[docs]class Serializable(yaml.YAMLObject): """ All model components that appear in a YAML file must inherit from Serializable. Implementing classes must specify a unique yaml_tag class attribute, e.g. ``yaml_tag = "!Serializable"`` """ @serializable_init def __init__(self) -> None: """ Initialize class, including allocation of DyNet parameters if needed. The __init__() must always be annotated with @serializable_init. It's arguments are exactly those that can be specified in a YAML config file. If the argument values are Serializable, they are initialized before being passed to this class. The order of the arguments defined here determines in what order children are initialized, which may be important when there are dependencies between children. """ # attributes that are in the YAML file (never change this manually, use Serializable.save_processed_arg() instead) self.serialize_params = {}
[docs] def shared_params(self) -> List[Set[Union[str,'Path']]]: """ Return the shared parameters of this Serializable class. This can be overwritten to specify what parameters of this component and its subcomponents are shared. Parameter sharing is performed before any components are initialized, and can therefore only include basic data types that are already present in the YAML file (e.g. # dimensions, etc.) Sharing is performed if at least one parameter is specified and multiple shared parameters don't conflict. In case of conflict a warning is printed, and no sharing is performed. The ordering of shared parameters is irrelevant. Note also that if a submodule is replaced by a reference, its shared parameters are ignored. Returns: objects referencing params of this component or a subcompononent e.g.:: return [set([".input_dim", ".sub_module.input_dim", ".submodules_list.0.input_dim"])] """ return []
[docs] def save_processed_arg(self, key: str, val: Any) -> None: """ Save a new value for an init argument (call from within ``__init__()``). Normally, the serialization mechanism makes sure that the same arguments are passed when creating the class initially based on a config file, and when loading it from a saved model. This method can be called from inside ``__init__()`` to save a new value that will be passed when loading the saved model. This can be useful when one doesn't want to recompute something every time (like a vocab) or when something has been passed via implicit referencing which might yield inconsistent result when loading the model to assemble a new model of different structure. Args: key: name of property, must match an argument of ``__init__()`` val: new value; a :class:`Serializable` or basic Python type or list or dict of these """ if not hasattr(self, "serialize_params"): self.serialize_params = {} if key!="xnmt_subcol_name" and key not in _get_init_args_defaults(self): raise ValueError(f"{key} is not an init argument of {self}") self.serialize_params[key] = val
[docs] def add_serializable_component(self, name: str, passed: Any, create_fct: Callable[[], Any]) -> Any: """ Create a :class:`Serializable` component, or a container component with several :class:`Serializable`-s. :class:`Serializable` sub-components should always be created using this helper to make sure DyNet parameters are assigned properly and serialization works properly. The components must also be accepted as init arguments, defaulting to ``None``. The helper makes sure that components are only created if ``None`` is passed, otherwise the passed component is reused. The idiom for using this for an argument named ``my_comp`` would be:: def __init__(self, my_comp=None, other_args, ...): ... my_comp = self.add_serializable_component("my_comp", my_comp, lambda: SomeSerializable(other_args)) # now, do something with my_comp ... Args: name: name of the object passed: object as passed in the constructor. If ``None``, will be created using create_fct. create_fct: a callable with no arguments that returns a :class:`Serializable` or a collection of :class:`Serializable`-s. When loading a saved model, this same object will be passed via the ``passed`` argument, and ``create_fct`` is not invoked. Returns: reused or newly created object(s). """ if passed is None: initialized = create_fct() self.save_processed_arg(name, initialized) return initialized else: return passed
def __repr__(self): if getattr(self, "_is_bare", False): return f"bare({self.__class__.__name__}{self._bare_kwargs if self._bare_kwargs else ''})" else: return f"{self.__class__.__name__}@{id(self)}"
[docs]class UninitializedYamlObject(object): """ Wrapper class to indicate an object created by the YAML parser that still needs initialization. Args: data: uninitialized object """ def __init__(self, data: Any) -> None: if isinstance(data, UninitializedYamlObject): raise AssertionError self.data = data def get(self, key: str, default: Any) -> Any: return self.data.get(key, default)
T = TypeVar('T')
[docs]def bare(class_type: Type[T], **kwargs: Any) -> T: """ Create an uninitialized object of arbitrary type. This is useful to specify XNMT components as default arguments. ``__init__()`` commonly requires DyNet parameters, component referencing, etc., which are not yet set up at the time the default arguments are loaded. In this case, a bare class can be specified with the desired arguments, and will be properly initialized when passed as arguments into a component. Args: class_type: class type (must be a subclass of :class:`Serializable`) kwargs: will be passed to class's ``__init__()`` Returns: uninitialized object """ obj = class_type.__new__(class_type) assert isinstance(obj, Serializable) for key, val in kwargs.items(): setattr(obj, key, val) setattr(obj, "_is_bare", True) setattr(obj, "_bare_kwargs", kwargs) return obj
[docs]class Ref(Serializable): """ A reference to somewhere in the component hierarchy. Components can be referenced by path or by name. Args: path: reference by path name: reference by name. The name refers to a unique ``_xnmt_id`` property that must be set in exactly one component. """ yaml_tag = "!Ref" NO_DEFAULT = 1928437192847 @serializable_init def __init__(self, path: Union[None, 'Path', str] = None, name: Optional[str] = None, default: Any = NO_DEFAULT) -> None: if name is not None and path is not None: raise ValueError(f"Ref cannot be initialized with both a name and a path ({name} / {path})") if isinstance(path, str): path = Path(path) self.name = name self.path = path self.default = default self.serialize_params = {'name': name} if name else {'path': str(path)}
[docs] def get_name(self) -> str: """Return name, or ``None`` if this is not a named reference""" return getattr(self, "name", None)
[docs] def get_path(self) -> Optional['Path']: """Return path, or ``None`` if this is a named reference""" if getattr(self, "path", None): if isinstance(self.path, str): self.path = Path(self.path) return self.path return None
[docs] def is_required(self) -> bool: """Return ``True`` iff there exists no default value and it is mandatory that this reference be resolved.""" return getattr(self, "default", Ref.NO_DEFAULT) == Ref.NO_DEFAULT
[docs] def get_default(self) -> Any: """Return default value, or ``Ref.NO_DEFAULT`` if no default value is set (i.e., this is a required reference).""" return getattr(self, "default", None)
def __str__(self): default_str = f", default={self.default}" if getattr(self, "default", Ref.NO_DEFAULT) != Ref.NO_DEFAULT else "" if self.get_name(): return f"Ref(name={self.get_name()}{default_str})" else: return f"Ref(path={self.get_path()}{default_str})" def __repr__(self): return str(self)
[docs] def resolve_path(self, named_paths: Dict[str, 'Path']) -> 'Path': """Get path, resolving paths properly in case this is a named reference.""" if self.get_path(): if isinstance(self.get_path(), str): # need to do this here, because the initializer is never called when # Ref objects are specified in the YAML file self.path = Path(self.get_path()) return self.path elif self.get_name() in named_paths: return named_paths[self.get_name()] else: raise ValueError(f"Could not resolve path of reference {self}")
[docs]class Path(object): """ A relative or absolute path in the component hierarchy. Paths are immutable: Operations that change the path always return a new Path object. Args: path_str: path string, with period ``.`` as separator. If prefixed by ``.``, marks a relative path, otherwise absolute. """ def __init__(self, path_str: str = "") -> None: if (len(path_str) > 1 and path_str[-1] == "." and path_str[-2] != ".") \ or ".." in path_str.strip("."): raise ValueError(f"'{path_str}' is not a valid path string") self.path_str = path_str
[docs] def append(self, link: str) -> 'Path': """ Return a new path by appending a link. Args: link: link to append Returns: new path """ if not link or "." in link: raise ValueError(f"'{link}' is not a valid link") if len(self.path_str.strip(".")) == 0: return Path(f"{self.path_str}{link}") else: return Path(f"{self.path_str}.{link}")
[docs] def add_path(self, path_to_add: 'Path') -> 'Path': """ Concatenates a path Args: path_to_add: path to concatenate Returns: concatenated path """ if path_to_add.is_relative_path(): raise NotImplementedError("add_path() is not implemented for relative paths.") if len(self.path_str.strip(".")) == 0 or len(path_to_add.path_str) == 0: return Path(f"{self.path_str}{path_to_add.path_str}") else: return Path(f"{self.path_str}.{path_to_add.path_str}")
def __str__(self): return self.path_str def __repr__(self): return self.path_str def is_relative_path(self) -> bool: return self.path_str.startswith(".") def get_absolute(self, rel_to: 'Path') -> 'Path': if rel_to.is_relative_path(): raise ValueError("rel_to must be an absolute path!") if self.is_relative_path(): num_up = len(self.path_str) - len(self.path_str.strip(".")) - 1 for _ in range(num_up): rel_to = rel_to.parent() s = self.path_str.strip(".") if len(s) > 0: for link in s.split("."): rel_to = rel_to.append(link) return rel_to else: return self def descend_one(self) -> 'Path': if self.is_relative_path() or len(self) == 0: raise ValueError(f"Can't call descend_one() on path {self.path_str}") return Path(".".join(self.path_str.split(".")[1:])) def __len__(self): if self.is_relative_path(): raise ValueError(f"Can't call __len__() on path {self.path_str}") if len(self.path_str) == 0: return 0 return len(self.path_str.split(".")) def __getitem__(self, key): if self.is_relative_path(): raise ValueError(f"Can't call __getitem__() on path {self.path_str}") if isinstance(key, slice): _, _, step = key.indices(len(self)) if step is not None and step != 1: raise ValueError(f"step must be 1, found {step}") return Path(".".join(self.path_str.split(".")[key])) else: return self.path_str.split(".")[key] def parent(self) -> 'Path': if len(self.path_str.strip(".")) == 0: raise ValueError(f"Path '{self.path_str}' has no parent") else: spl = self.path_str.split(".")[:-1] if '.'.join(spl) == "" and self.path_str.startswith("."): return Path(".") else: return Path(".".join(spl)) def __hash__(self): return hash(self.path_str) def __eq__(self, other): if isinstance(other, Path): return self.path_str == other.path_str else: return False def ancestors(self) -> Set['Path']: a = self ret = {a} while len(a.path_str.strip(".")) > 0: a = a.parent() ret.add(a) return ret
[docs]class Repeat(Serializable): """ A special object that is replaced by a list of components with identical configuration but not with shared params. This can be specified anywhere in the config hierarchy where normally a list is expected. A common use case is a multi-layer neural architecture, where layer configurations are repeated many times. It is replaced in the preloader and cannot be instantiated directly. """ yaml_tag = "!Repeat" @serializable_init def __init__(self, times: numbers.Integral, content: Any) -> None: self.times = times self.content = content raise ValueError("Repeat cannot be instantiated")
_subcol_rand = random.Random() def _generate_subcol_name(subcol_owner): rand_bits = _subcol_rand.getrandbits(32) rand_hex = "%008x" % rand_bits return f"{type(subcol_owner).__name__}.{rand_hex}" _reserved_arg_names = ["_xnmt_id", "yaml_path", "serialize_params", "init_params", "kwargs", "self", "xnmt_subcol_name", "init_completed"] def _get_init_args_defaults(obj): return inspect.signature(obj.__init__).parameters def _check_serializable_args_valid(node): base_arg_names = map(lambda x: x[0], inspect.getmembers(yaml.YAMLObject)) class_param_names = [x[0] for x in inspect.getmembers(node.__class__)] init_args = _get_init_args_defaults(node) items = {key: val for (key, val) in inspect.getmembers(node)} for name in items: if name in base_arg_names or name in class_param_names: continue if name.startswith("_") or name in _reserved_arg_names: continue if name not in init_args: raise ValueError( f"'{name}' is not a accepted argument of {type(node).__name__}.__init__(). Valid are {list(init_args.keys())}") @singledispatch def _name_serializable_children(node): return _name_children(node, include_reserved=False) @_name_serializable_children.register(Serializable) def _name_serializable_children_serializable(node): return getattr(node, "serialize_params", {}).items() @singledispatch def _name_children(node, include_reserved): return [] @_name_children.register(Serializable) def _name_children_serializable(node, include_reserved): """ Returns the specified arguments in the order they appear in the corresponding ``__init__()`` """ init_args = list(_get_init_args_defaults(node).keys()) if include_reserved: init_args += [n for n in _reserved_arg_names if not n in init_args] items = {key: val for (key, val) in inspect.getmembers(node)} ret = [] for name in init_args: if name in items: val = items[name] ret.append((name, val)) return ret @_name_children.register(dict) def _name_children_dict(node, include_reserved): return node.items() @_name_children.register(list) def _name_children_list(node, include_reserved): return [(str(n), l) for n, l in enumerate(node)] @_name_children.register(tuple) def _name_children_tuple(node, include_reserved): raise ValueError(f"Tuples are not serializable, use a list instead. Found this tuple: {node}.") @singledispatch def _get_child(node, name): if not hasattr(node, name): raise PathError(f"{node} has no child named {name}") return getattr(node, name) @_get_child.register(list) def _get_child_list(node, name): try: name = int(name) except: raise PathError(f"{node} has no child named {name} (integer expected)") if not 0 <= name < len(node): raise PathError(f"{node} has no child named {name} (index error)") return node[int(name)] @_get_child.register(dict) def _get_child_dict(node, name): if not name in node.keys(): raise PathError(f"{node} has no child named {name} (key error)") return node[name] @_get_child.register(Serializable) def _get_child_serializable(node, name): # if hasattr(node, "serialize_params"): # return _get_child(node.serialize_params, name) # else: if not hasattr(node, name): raise PathError(f"{node} has no child named {name}") return getattr(node, name) @singledispatch def _set_child(node, name, val): pass @_set_child.register(Serializable) def _set_child_serializable(node, name, val): setattr(node, name, val) @_set_child.register(list) def _set_child_list(node, name, val): if name == "append": name = len(node) try: name = int(name) except: raise PathError(f"{node} has no child named {name} (integer expected)") if not 0 <= name < len(node)+1: raise PathError(f"{node} has no child named {name} (index error)") if name == len(node): node.append(val) else: node[int(name)] = val @_set_child.register(dict) def _set_child_dict(node, name, val): node[name] = val def _redirect_path_untested(path, root, cur_node=None): # note: this might become useful in the future, but is not carefully tested, use with care if cur_node is None: cur_node = root if len(path) == 0: if isinstance(cur_node, Ref): return cur_node.get_path() return path elif isinstance(cur_node, Ref): assert not cur_node.get_path().is_relative_path() return _redirect_path_untested(cur_node.get_path(), root, _get_descendant(root, cur_node.get_path())) else: try: return path[:1].add_path(_redirect_path_untested(path.descend_one(), root, _get_child(cur_node, path[0]))) except PathError: # child does not exist return path def _get_descendant(node, path, redirect=False): if len(path) == 0: return node elif redirect and isinstance(node, Ref): node_path = node.get_path() if isinstance(node_path, str): node_path = Path(node_path) return Ref(node_path.add_path(path), default=node.get_default()) else: return _get_descendant(_get_child(node, path[0]), path.descend_one(), redirect=redirect) def _set_descendant(root, path, val): if len(path) == 0: raise ValueError("path was empty") elif len(path) == 1: _set_child(root, path[0], val) else: _set_descendant(_get_child(root, path[0]), path.descend_one(), val) class _TraversalOrder(IntEnum): ROOT_FIRST = auto() ROOT_LAST = auto() def _traverse_tree(node, traversal_order=_TraversalOrder.ROOT_FIRST, path_to_node=Path(), include_root=True): """ For each node in the tree, yield a (path, node) tuple """ if include_root and traversal_order == _TraversalOrder.ROOT_FIRST: yield path_to_node, node for child_name, child in _name_children(node, include_reserved=False): yield from _traverse_tree(child, traversal_order, path_to_node.append(child_name)) if include_root and traversal_order == _TraversalOrder.ROOT_LAST: yield path_to_node, node def _traverse_serializable(root, path_to_node=Path()): yield path_to_node, root for child_name, child in _name_serializable_children(root): yield from _traverse_serializable(child, path_to_node.append(child_name)) def _traverse_serializable_breadth_first(root): all_nodes = [(path, node) for (path, node) in _traverse_serializable(root)] all_nodes = [item[1] for item in sorted(enumerate(all_nodes), key=lambda x: (len(x[1][0]), x[0]))] return iter(all_nodes) def _traverse_tree_deep(root, cur_node, traversal_order=_TraversalOrder.ROOT_FIRST, path_to_node=Path(), named_paths=None, past_visits=set()): """ Traverse the tree and descend into references. The returned path is that of the resolved reference. Args: root (Serializable): cur_node (Serializable): traversal_order (_TraversalOrder): path_to_node (Path): named_paths (dict): past_visits (set): """ # prevent infinite recursion: if named_paths is None: named_paths = {} cur_call_sig = (id(root), id(cur_node), path_to_node) if cur_call_sig in past_visits: return past_visits = set(past_visits) past_visits.add(cur_call_sig) if traversal_order == _TraversalOrder.ROOT_FIRST: yield path_to_node, cur_node if isinstance(cur_node, Ref): resolved_path = cur_node.resolve_path(named_paths) try: yield from _traverse_tree_deep(root, _get_descendant(root, resolved_path), traversal_order, resolved_path, named_paths, past_visits=past_visits) except PathError: pass else: for child_name, child in _name_children(cur_node, include_reserved=False): yield from _traverse_tree_deep(root, child, traversal_order, path_to_node.append(child_name), named_paths, past_visits=past_visits) if traversal_order == _TraversalOrder.ROOT_LAST: yield path_to_node, cur_node def _traverse_tree_deep_once(root, cur_node, traversal_order=_TraversalOrder.ROOT_FIRST, path_to_node=Path(), named_paths=None): """ Calls _traverse_tree_deep, but skips over nodes that have been visited before (can happen because we're descending into references). """ if named_paths is None: named_paths = {} yielded_paths = set() for path, node in _traverse_tree_deep(root, cur_node, traversal_order, path_to_node, named_paths): if not (path.ancestors() & yielded_paths): yielded_paths.add(path) yield (path, node) def _get_named_paths(root): d = {} for path, node in _traverse_tree(root): if "_xnmt_id" in [name for (name, _) in _name_children(node, include_reserved=True)]: xnmt_id = _get_child(node, "_xnmt_id") if xnmt_id in d: raise ValueError(f"_xnmt_id {xnmt_id} was specified multiple times!") d[xnmt_id] = path return d
[docs]class PathError(Exception): def __init__(self, message: str) -> None: super().__init__(message)
[docs]class SavedFormatString(str, Serializable): yaml_tag = "!SavedFormatString" @serializable_init def __init__(self, value: str, unformatted_value: str) -> None: self.unformatted_value = unformatted_value self.value = value
[docs]class FormatString(str, yaml.YAMLObject): """ Used to handle the ``{EXP}`` string formatting syntax. When passed around it will appear like the properly resolved string, but writing it back to YAML will use original version containing ``{EXP}`` """ def __new__(cls, value: str, *args, **kwargs) -> 'FormatString': return super().__new__(cls, value) def __init__(self, value: str, serialize_as: str) -> None: self.value = value self.serialize_as = serialize_as
def _init_fs_representer(dumper, obj): return dumper.represent_mapping('!SavedFormatString', {"value":obj.value,"unformatted_value":obj.serialize_as}) # return dumper.represent_data(SavedFormatString(value=obj.value, unformatted_value=obj.serialize_as)) yaml.add_representer(FormatString, _init_fs_representer)
[docs]class RandomParam(yaml.YAMLObject): yaml_tag = '!RandomParam' def __init__(self, values: list) -> None: self.values = values def __repr__(self): return f"{self.__class__.__name__}(values={self.values})" def draw_value(self) -> Any: if not hasattr(self, 'drawn_value'): self.drawn_value = random.choice(self.values) return self.drawn_value
[docs]class LoadSerialized(Serializable): """ Load content from an external YAML file. This object points to an object in an external YAML file and will be replaced by the corresponding content by the YAMLPreloader. Args: filename: YAML file name to load from path: path inside the YAML file to load from, with ``.`` separators. Empty string denotes root. overwrite: allows overwriting parts of the loaded model with new content. A list of path/val dictionaries, where ``path`` is a path string relative to the loaded sub-object following the syntax of :class:`Path`, and ``val`` is a Yaml-serializable specifying the new content. E.g.:: [{"path" : "model.trainer", "val":AdamTrainer()}, {"path" : ..., "val":...}] It is possible to specify the path to point to a new key to a dictionary. If ``path`` points to a list, it's possible append to that list by using ``append_val`` instead of ``val``. """ yaml_tag = "!LoadSerialized" @serializable_init def __init__(self, filename: str, path: str = "", overwrite: Optional[List[Dict[str,Any]]] = None) -> None: if overwrite is None: overwrite = [] self.filename = filename self.path = path self.overwrite = overwrite @staticmethod def _check_wellformed(load_serialized): _check_serializable_args_valid(load_serialized) if hasattr(load_serialized, "overwrite"): if not isinstance(load_serialized.overwrite, list): raise ValueError(f"LoadSerialized.overwrite must be a list, found: {type(load_serialized.overwrite)}") for item in load_serialized.overwrite: if not isinstance(item, dict): raise ValueError(f"LoadSerialized.overwrite must be a list of dictionaries, found list item: {type(item)}") if item.keys() != {"path", "val"}: raise ValueError(f"Each overwrite item must have 'path', 'val' (and no other) keys. Found: {item.keys()}")
[docs]class YamlPreloader(object): """ Loads experiments from YAML and performs basic preparation, but does not initialize objects. Has the following responsibilities: * takes care of extracting individual experiments from a YAML file * replaces ``!LoadSerialized`` by loading the corresponding content * resolves kwargs syntax (items from a kwargs dictionary are moved to the owner where they become object attributes) * implements random search (draws proper random values when ``!RandomParam`` is encountered) * finds and replaces placeholder strings such as ``{EXP}``, ``{EXP_DIR}``, ``{GIT_REV}``, and ``{PID}`` * copies bare default arguments into the corresponding objects where appropriate. Typically, :meth:`initialize_object` would be invoked by passing the result from the ``YamlPreloader``. """
[docs] @staticmethod def experiment_names_from_file(filename:str) -> List[str]: """Return list of experiment names. Args: filename: path to YAML file Returns: experiment names occuring in the given file in lexicographic order. """ try: with open(filename) as stream: experiments = yaml.load(stream) except IOError as e: raise RuntimeError(f"Could not read configuration file {filename}: {e}") except yaml.constructor.ConstructorError: logger.error( "for proper deserialization of a class object, make sure the class is a subclass of " "xnmt.serialize.serializable.Serializable, specifies a proper yaml_tag with leading '!', and its module is " "imported under xnmt/__init__.py") raise if isinstance(experiments, dict): if "defaults" in experiments: del experiments["defaults"] return sorted(experiments.keys()) elif isinstance(experiments, list): exp_names = [] for exp in experiments: if not hasattr(exp, "name"): raise ValueError("Encountered unnamed experiment.") if exp.name != "default": exp_names.append(exp.name) if len(exp_names) != len(set(exp_names)): raise ValueError(f"Found duplicate experiment names: {exp_names}.") return exp_names else: if experiments.__class__.__name__ != "Experiment": raise TypeError(f"Top level of config file must be a single Experiment or a list or dict of experiments." f"Found: {experiments} of type {type(experiments)}.") if not hasattr(experiments, "name"): raise ValueError("Encountered unnamed experiment.") return [experiments.name]
[docs] @staticmethod def preload_experiment_from_file(filename: str, exp_name: str, resume: bool = False) -> UninitializedYamlObject: """Preload experiment from YAML file. Args: filename: YAML config file name exp_name: experiment name to load resume: set to True if we are loading a saved model file directly and want to restore all formatted strings. Returns: Preloaded but uninitialized object. """ try: with open(filename) as stream: config = yaml.load(stream) except IOError as e: raise RuntimeError(f"Could not read configuration file {filename}: {e}") if isinstance(config, dict): experiment = config[exp_name] if getattr(experiment, "name", exp_name) != exp_name: raise ValueError(f"Inconsistent experiment name '{exp_name}' / '{experiment.name}'") if not isinstance(experiment, LoadSerialized): experiment.name = exp_name elif isinstance(config, list): experiment = None for exp in config: if not hasattr(exp, "name"): raise ValueError("Encountered unnamed experiment.") if exp.name==exp_name: experiment = exp if exp is None: raise ValueError(f"No experiment of name '{exp_name}' exists.") else: experiment = config if not hasattr(experiment, "name"): raise ValueError("Encountered unnamed experiment.") if not isinstance(experiment, LoadSerialized): if experiment.name != exp_name: raise ValueError(f"No experiment of name '{exp_name}' exists.") return YamlPreloader.preload_obj(experiment, exp_name=exp_name, exp_dir=os.path.dirname(filename) or ".", resume=resume)
[docs] @staticmethod def preload_obj(root: Any, exp_name: str, exp_dir: str, resume: bool = False) -> UninitializedYamlObject: """Preload a given object. Preloading a given object, usually an :class:`xnmt.experiment.Experiment` or :class:`LoadSerialized` object as parsed by pyyaml, includes replacing ``!LoadSerialized``, resolving ``kwargs`` syntax, and instantiating random search. Args: root: object to preload exp_name: experiment name, needed to replace ``{EXP}`` exp_dir: directory of the corresponding config file, needed to replace ``{EXP_DIR}`` resume: if True, keep the formatted strings, e.g. set ``{EXP}`` to the value of the previous run if possible Returns: Preloaded but uninitialized object. """ for _, node in _traverse_tree(root): if isinstance(node, Serializable): YamlPreloader._resolve_kwargs(node) YamlPreloader._copy_duplicate_components(root) # sometimes duplicate objects occur with yaml.load() placeholders = {"EXP": exp_name, "PID": os.getpid(), "EXP_DIR": exp_dir, "GIT_REV": tee.get_git_revision()} # do this both before and after resolving !LoadSerialized root = YamlPreloader._remove_saved_format_strings(root, keep_value=resume) YamlPreloader._format_strings(root, placeholders) root = YamlPreloader._load_serialized(root) random_search_report = YamlPreloader._instantiate_random_search(root) if random_search_report: setattr(root, 'random_search_report', random_search_report) root = YamlPreloader._resolve_repeat(root) # if arguments were not given in the YAML file and are set to a bare(Serializable) by default, copy the bare object # into the object hierarchy so it can be used w/ param sharing etc. YamlPreloader._resolve_bare_default_args(root) # do this both before and after resolving !LoadSerialized root = YamlPreloader._remove_saved_format_strings(root, keep_value=resume) YamlPreloader._format_strings(root, placeholders) return UninitializedYamlObject(root)
@staticmethod def _load_serialized(root: Any) -> Any: for path, node in _traverse_tree(root, traversal_order=_TraversalOrder.ROOT_LAST): if isinstance(node, LoadSerialized): LoadSerialized._check_wellformed(node) try: with open(node.filename) as stream: loaded_root = yaml.load(stream) except IOError as e: raise RuntimeError(f"Could not read configuration file {node.filename}: {e}") if os.path.isdir(f"{node.filename}.data"): param_collections.ParamManager.add_load_path(f"{node.filename}.data") cur_path = Path(getattr(node, "path", "")) for _ in range(10): # follow references loaded_trg = _get_descendant(loaded_root, cur_path, redirect=True) if isinstance(loaded_trg, Ref): cur_path = loaded_trg.get_path() else: break found_outside_ref = True self_inserted_ref_ids = set() while found_outside_ref: found_outside_ref = False named_paths = _get_named_paths(loaded_root) replaced_paths = {} for sub_path, sub_node in _traverse_tree(loaded_trg, path_to_node=cur_path): if isinstance(sub_node, Ref) and not id(sub_node) in self_inserted_ref_ids: referenced_path = sub_node.resolve_path(named_paths) if referenced_path.is_relative_path(): raise NotImplementedError("Handling of relative paths with LoadSerialized is not yet implemented.") if referenced_path in replaced_paths: new_ref = Ref(replaced_paths[referenced_path], default=sub_node.get_default()) _set_descendant(loaded_trg, sub_path[len(cur_path):], new_ref) self_inserted_ref_ids.add(id(new_ref)) # if outside node: elif not str(referenced_path).startswith(str(cur_path)): found_outside_ref = True referenced_obj = _get_descendant(loaded_root, referenced_path) _set_descendant(loaded_trg, sub_path[len(cur_path):], referenced_obj) # replaced_paths[referenced_path] = sub_path replaced_paths[referenced_path] = path.add_path(sub_path[len(cur_path):]) else: new_ref = Ref(path.add_path(referenced_path[len(cur_path):]), default=sub_node.get_default()) _set_descendant(loaded_trg, sub_path[len(cur_path):], new_ref) self_inserted_ref_ids.add(id(new_ref)) for d in getattr(node, "overwrite", []): overwrite_path = Path(d["path"]) _set_descendant(loaded_trg, overwrite_path, d["val"]) if len(path) == 0: root = loaded_trg else: _set_descendant(root, path, loaded_trg) return root @staticmethod def _copy_duplicate_components(root): obj_ids = set() for path, node in _traverse_tree(root, _TraversalOrder.ROOT_LAST): if isinstance(node, (list, dict, Serializable)): if id(node) in obj_ids: _set_descendant(root, path, copy.deepcopy(node)) obj_ids.add(id(node)) @staticmethod def _resolve_kwargs(obj: Any) -> None: """ If obj has a kwargs attribute (dictionary), set the dictionary items as attributes of the object via setattr (asserting that there are no collisions). """ if hasattr(obj, "kwargs"): for k, v in obj.kwargs.items(): if hasattr(obj, k): raise ValueError(f"kwargs '{str(k)}' already specified as class member for object '{str(obj)}'") setattr(obj, k, v) delattr(obj, "kwargs") @staticmethod def _instantiate_random_search(experiment): # TODO: this should probably be refactored: pull out of persistence.py and generalize so other things like # grid search and bayesian optimization can be supported param_report = {} initialized_random_params = {} for path, v in _traverse_tree(experiment): if isinstance(v, RandomParam): if hasattr(v, "_xnmt_id") and v._xnmt_id in initialized_random_params: v = initialized_random_params[v._xnmt_id] v = v.draw_value() if hasattr(v, "_xnmt_id"): initialized_random_params[v._xnmt_id] = v _set_descendant(experiment, path, v) param_report[path] = v return param_report @staticmethod def _resolve_repeat(root): for path, node in _traverse_tree(root, traversal_order=_TraversalOrder.ROOT_LAST): if isinstance(node, Repeat): expanded = [] for _ in range(node.times): expanded.append(copy.deepcopy(node.content)) if len(path) == 0: root = expanded else: _set_descendant(root, path, expanded) return root @staticmethod def _remove_saved_format_strings(root, keep_value=False): for path, node in _traverse_tree(root, traversal_order=_TraversalOrder.ROOT_LAST): if isinstance(node, SavedFormatString): replace_by = node.value if keep_value else node.unformatted_value if len(path) == 0: root = replace_by else: _set_descendant(root, path, replace_by) return root @staticmethod def _resolve_bare_default_args(root: Any) -> None: for path, node in _traverse_tree(root): if isinstance(node, Serializable): init_args_defaults = _get_init_args_defaults(node) for expected_arg in init_args_defaults: if not expected_arg in [x[0] for x in _name_children(node, include_reserved=False)]: arg_default = init_args_defaults[expected_arg].default if isinstance(arg_default, Serializable) and not isinstance(arg_default, Ref): if not getattr(arg_default, "_is_bare", False): raise ValueError( f"only Serializables created via bare(SerializableSubtype) are permitted as default arguments; " f"found a fully initialized Serializable: {arg_default} at {path}") YamlPreloader._resolve_bare_default_args(arg_default) # apply recursively setattr(node, expected_arg, copy.deepcopy(arg_default)) @staticmethod def _format_strings(root: Any, format_dict: Dict[str, str]) -> None: """ - replaces strings containing ``{EXP}`` and other supported args - also checks if there are default arguments for which no arguments are set and instantiates them with replaced ``{EXP}`` if applicable """ try: format_dict.update(root.exp_global.placeholders) except AttributeError: pass for path, node in _traverse_tree(root): if isinstance(node, str): try: formatted = node.format(**format_dict) except (ValueError, KeyError, IndexError): # will occur e.g. if a vocab entry contains a curly bracket formatted = node if node != formatted: _set_descendant(root, path, FormatString(formatted, node)) elif isinstance(node, Serializable): init_args_defaults = _get_init_args_defaults(node) for expected_arg in init_args_defaults: if not expected_arg in [x[0] for x in _name_children(node, include_reserved=False)]: arg_default = init_args_defaults[expected_arg].default if isinstance(arg_default, str): try: formatted = arg_default.format(**format_dict) except (ValueError, KeyError): # will occur e.g. if a vocab entry contains a curly bracket formatted = arg_default if arg_default != formatted: setattr(node, expected_arg, FormatString(formatted, arg_default))
class _YamlDeserializer(object): def __init__(self): self.has_been_called = False def initialize_if_needed(self, obj: Union[Serializable,UninitializedYamlObject]) -> Serializable: """ Initialize if obj has not yet been initialized. Note: make sure to always create a new ``_YamlDeserializer`` before calling this, e.g. using ``_YamlDeserializer().initialize_object()`` Args: obj: object to be potentially serialized Returns: initialized object """ if self.is_initialized(obj): return obj else: return self.initialize_object(deserialized_yaml_wrapper=obj) @staticmethod def is_initialized(obj: Union[Serializable,UninitializedYamlObject]) -> bool: """ Returns: ``True`` if a serializable object's ``__init__()`` has been invoked (either programmatically or through YAML deserialization). ``False`` if ``__init__()`` has not been invoked, i.e. the object has been produced by the YAML parser but is not ready to use. """ return type(obj) != UninitializedYamlObject def initialize_object(self, deserialized_yaml_wrapper: Any) -> Any: """ Initializes a hierarchy of deserialized YAML objects. Note: make sure to always create a new ``_YamlDeserializer`` before calling this, e.g. using ``_YamlDeserializer().initialize_object()`` Args: deserialized_yaml_wrapper: deserialized YAML data inside a :class:`UninitializedYamlObject` wrapper (classes are resolved and class members set, but ``__init__()`` has not been called at this point) Returns: the appropriate object, with properly shared parameters and ``__init__()`` having been invoked """ assert not self.has_been_called self.has_been_called = True if self.is_initialized(deserialized_yaml_wrapper): raise AssertionError() # make a copy to avoid side effects self.deserialized_yaml = copy.deepcopy(deserialized_yaml_wrapper.data) # make sure only arguments accepted by the Serializable derivatives' __init__() methods were passed self.check_args(self.deserialized_yaml) # if arguments were not given in the YAML file and are set to a bare(Serializable) by default, copy the bare object into the object hierarchy so it can be used w/ param sharing etc. YamlPreloader._resolve_bare_default_args(self.deserialized_yaml) self.named_paths = _get_named_paths(self.deserialized_yaml) # if arguments were not given in the YAML file and are set to a Ref by default, copy this Ref into the object structure so that it can be properly resolved in a subsequent step self.resolve_ref_default_args(self.deserialized_yaml) # if references point to places that are not specified explicitly in the YAML file, but have given default arguments, substitute those default arguments self.create_referenced_default_args(self.deserialized_yaml) # apply sharing as requested by Serializable.shared_params() self.share_init_params_top_down(self.deserialized_yaml) # finally, initialize each component via __init__(**init_params), while properly resolving references initialized = self.init_components_bottom_up(self.deserialized_yaml) return initialized def check_args(self, root): for _, node in _traverse_tree(root): if isinstance(node, Serializable): _check_serializable_args_valid(node) def resolve_ref_default_args(self, root): for _, node in _traverse_tree(root): if isinstance(node, Serializable): init_args_defaults = _get_init_args_defaults(node) for expected_arg in init_args_defaults: if not expected_arg in [x[0] for x in _name_children(node, include_reserved=False)]: arg_default = copy.deepcopy(init_args_defaults[expected_arg].default) if isinstance(arg_default, Ref): setattr(node, expected_arg, arg_default) def create_referenced_default_args(self, root): for path, node in _traverse_tree(root): if isinstance(node, Ref): referenced_path = node.get_path() if not referenced_path: continue # skip named paths if isinstance(referenced_path, str): referenced_path = Path(referenced_path) give_up = False for ancestor in sorted(referenced_path.ancestors(), key = lambda x: len(x)): try: _get_descendant(root, ancestor) except PathError: try: ancestor_parent = _get_descendant(root, ancestor.parent()) if isinstance(ancestor_parent, Serializable): init_args_defaults = _get_init_args_defaults(ancestor_parent) if ancestor[-1] in init_args_defaults: referenced_arg_default = init_args_defaults[ancestor[-1]].default else: referenced_arg_default = inspect.Parameter.empty if referenced_arg_default != inspect.Parameter.empty: _set_descendant(root, ancestor, copy.deepcopy(referenced_arg_default)) else: give_up = True except PathError: pass if give_up: break def share_init_params_top_down(self, root): abs_shared_param_sets = [] for path, node in _traverse_tree(root): if isinstance(node, Serializable): for shared_param_set in node.shared_params(): shared_param_set = set(Path(p) if isinstance(p, str) else p for p in shared_param_set) abs_shared_param_set = set(p.get_absolute(path) for p in shared_param_set) added = False for prev_set in abs_shared_param_sets: if prev_set & abs_shared_param_set: prev_set |= abs_shared_param_set added = True break if not added: abs_shared_param_sets.append(abs_shared_param_set) for shared_param_set in abs_shared_param_sets: shared_val_choices = set() for shared_param_path in shared_param_set: try: new_shared_val = _get_descendant(root, shared_param_path) except PathError: continue for _, child_of_shared_param in _traverse_tree(new_shared_val, include_root=False): if isinstance(child_of_shared_param, Serializable): raise ValueError(f"{path} shared params {shared_param_set} contains Serializable sub-object {child_of_shared_param} which is not permitted") if not isinstance(new_shared_val, Ref): shared_val_choices.add(new_shared_val) if len(shared_val_choices)>1: logger.warning(f"inconsistent shared params at {path} for {shared_param_set}: {shared_val_choices}; Ignoring these shared parameters.") elif len(shared_val_choices)==1: for shared_param_path in shared_param_set: try: if shared_param_path[-1] in _get_init_args_defaults(_get_descendant(root, shared_param_path.parent())): _set_descendant(root, shared_param_path, list(shared_val_choices)[0]) except PathError: pass # can happen when the shared path contained a reference, which we don't follow to avoid unwanted effects def init_components_bottom_up(self, root): for path, node in _traverse_tree_deep_once(root, root, _TraversalOrder.ROOT_LAST, named_paths=self.named_paths): if isinstance(node, Serializable): if isinstance(node, Ref): hits_before = self.init_component.cache_info().hits try: resolved_path = node.resolve_path(self.named_paths) initialized_component = self.init_component(resolved_path) except PathError: if getattr(node, "default", Ref.NO_DEFAULT) == Ref.NO_DEFAULT: initialized_component = None else: initialized_component = copy.deepcopy(node.default) if self.init_component.cache_info().hits > hits_before: logger.debug(f"for {path}: reusing previously initialized {initialized_component}") else: initialized_component = self.init_component(path) if len(path)==0: root = initialized_component else: _set_descendant(root, path, initialized_component) return root def check_init_param_types(self, obj, init_params): for init_param_name in init_params: param_sig = _get_init_args_defaults(obj) if init_param_name in param_sig: annotated_type = param_sig[init_param_name].annotation if annotated_type != inspect.Parameter.empty: if not check_type(init_params[init_param_name], annotated_type): raise ValueError(f"type check failed for '{init_param_name}' argument of {obj}: expected {annotated_type}, received {init_params[init_param_name]} of type {type(init_params[init_param_name])}") @lru_cache(maxsize=None) def init_component(self, path): """ Args: path: path to uninitialized object Returns: initialized object; this method is cached, so multiple requests for the same path will return the exact same object """ obj = _get_descendant(self.deserialized_yaml, path) if not isinstance(obj, Serializable) or isinstance(obj, FormatString): return obj init_params = OrderedDict(_name_children(obj, include_reserved=False)) init_args = _get_init_args_defaults(obj) if "yaml_path" in init_args: init_params["yaml_path"] = path self.check_init_param_types(obj, init_params) with utils.ReportOnException({"yaml_path":path}): try: if hasattr(obj, "xnmt_subcol_name"): initialized_obj = obj.__class__(**init_params, xnmt_subcol_name=obj.xnmt_subcol_name) else: initialized_obj = obj.__class__(**init_params) logger.debug(f"initialized {path}: {obj.__class__.__name__}@{id(obj)}({dict(init_params)})"[:1000]) except TypeError as e: raise ComponentInitError(f"An error occurred when calling {type(obj).__name__}.__init__()\n" f" The following arguments were passed: {init_params}\n" f" The following arguments were expected: {init_args.keys()}\n" f" Current path: {path}\n" f" Error message: {e}") return initialized_obj def _resolve_serialize_refs(root): all_serializable = set() # for DyNet param check # gather all non-basic types (Serializable, list, dict) in the global dictionary xnmt.resolved_serialize_params for _, node in _traverse_serializable(root): if isinstance(node, Serializable): all_serializable.add(id(node)) if not hasattr(node, "serialize_params"): raise ValueError(f"Cannot serialize node that has no serialize_params attribute: {node}\n" "Did you forget to wrap the __init__() in @serializable_init ?") xnmt.resolved_serialize_params[id(node)] = node.serialize_params elif isinstance(node, collections.abc.MutableMapping): xnmt.resolved_serialize_params[id(node)] = dict(node) elif isinstance(node, collections.abc.MutableSequence): xnmt.resolved_serialize_params[id(node)] = list(node) if not set(id(o) for o in param_collections.ParamManager.param_col.all_subcol_owners) <= all_serializable: raise RuntimeError(f"Not all registered DyNet parameter collections written out. " f"Missing: {param_collections.ParamManager.param_col.all_subcol_owners - all_serializable}.\n" f"This indicates that potentially not all components adhere to the protocol of using " f"Serializable.add_serializable_component() for creating serializable sub-components.") refs_inserted_at = set() refs_inserted_to = set() for path_trg, node_trg in _traverse_serializable(root): # loop potential reference targets if not refs_inserted_at & path_trg.ancestors(): # skip target if it or its ancestor has already been replaced by a reference if isinstance(node_trg, Serializable): for path_src, node_src in _traverse_serializable(root): # loop potential nodes that should be replaced by a reference to the current target if not path_src in refs_inserted_to: # don't replace by reference if someone is pointing to this node already if path_src!=path_trg and node_src is node_trg: # don't reference to self # now we're ready to create a reference from node_src to node_trg (node_src will be replaced, node_trg remains unchanged) ref = Ref(path=path_trg) xnmt.resolved_serialize_params[id(ref)] = ref.serialize_params # make sure the reference itself can be properly serialized src_node_parent = _get_descendant(root, path_src.parent()) src_node_parent_serialize_params = xnmt.resolved_serialize_params[id(src_node_parent)] _set_descendant(src_node_parent_serialize_params, Path(path_src[-1]), ref) if isinstance(src_node_parent, (collections.abc.MutableMapping, collections.abc.MutableSequence)): assert isinstance(_get_descendant(root, path_src.parent().parent()), Serializable), \ "resolving references inside nested lists/dicts is not yet implemented" src_node_grandparent = _get_descendant(root, path_src.parent().parent()) src_node_parent_name = path_src[-2] xnmt.resolved_serialize_params[id(src_node_grandparent)][src_node_parent_name] = \ xnmt.resolved_serialize_params[id(src_node_parent)] refs_inserted_at.add(path_src) refs_inserted_to.add(path_trg) def _dump(ser_obj): assert len(xnmt.resolved_serialize_params)==0 _resolve_serialize_refs(ser_obj) ret = yaml.dump(ser_obj) xnmt.resolved_serialize_params.clear() return ret
[docs]def save_to_file(fname: str, mod: Any) -> None: """ Save a component hierarchy and corresponding DyNet parameter collection to disk. Args: fname: Filename to save to. mod: Component hierarchy. """ dirname = os.path.dirname(fname) if dirname and not os.path.exists(dirname): os.makedirs(dirname) with open(fname, 'w') as f: f.write(_dump(mod)) param_collections.ParamManager.param_col.save()
[docs]def initialize_if_needed(root: Union[Any, UninitializedYamlObject]) -> Any: """ Initialize if obj has not yet been initialized. This includes parameter sharing and resolving of references. Args: root: object to be potentially serialized Returns: initialized object """ return _YamlDeserializer().initialize_if_needed(root)
[docs]def initialize_object(root: UninitializedYamlObject) -> Any: """ Initialize an uninitialized object. This includes parameter sharing and resolving of references. Args: root: object to be serialized Returns: initialized object """ return _YamlDeserializer().initialize_object(root)
[docs]class ComponentInitError(Exception): pass
[docs]def check_type(obj, desired_type): """ Checks argument types using isinstance, or some custom logic if type hints from the 'typing' module are given. Regarding type hints, only a few major ones are supported. This should cover almost everything that would be expected in a YAML config file, but might miss a few special cases. For unsupported types, this function evaluates to True. Most notably, forward references such as 'SomeType' (with apostrophes around the type) are not supported. Note also that typing.Tuple is among the unsupported types because tuples aren't supported by the XNMT serializer. Args: obj: object whose type to check desired_type: desired type of obj Returns: False if types don't match or desired_type is unsupported, True otherwise. """ try: if isinstance(obj, desired_type): return True if isinstance(obj, Serializable): # handle some special issues, probably caused by inconsistent imports: if obj.__class__.__name__ == desired_type.__name__ or any( base.__name__ == desired_type.__name__ for base in obj.__class__.__bases__): return True return False except TypeError: if type(desired_type) == str: return True # don't support forward type references if desired_type.__class__.__name__ == "_Any": return True elif desired_type == type(None): return obj is None elif desired_type.__class__.__name__ == "_Union": return any( subtype.__class__.__name__ == "_ForwardRef" or check_type(obj, subtype) for subtype in desired_type.__args__) elif issubclass(desired_type.__class__, collections.abc.MutableMapping): if not isinstance(obj, collections.abc.MutableMapping): return False if desired_type.__args__: return (desired_type.__args__[0].__class__.__name__ == "_ForwardRef" or all( check_type(key, desired_type.__args__[0]) for key in obj.keys())) and ( desired_type.__args__[1].__class__.__name__ == "_ForwardRef" or all( check_type(val, desired_type.__args__[1]) for val in obj.values())) else: return True elif issubclass(desired_type.__class__, collections.abc.Sequence): if not isinstance(obj, collections.abc.Sequence): return False if desired_type.__args__ and desired_type.__args__[0].__class__.__name__ != "_ForwardRef": return all(check_type(item, desired_type.__args__[0]) for item in obj) else: return True elif desired_type.__class__.__name__ == "TupleMeta": if not isinstance(obj, tuple): return False if desired_type.__args__: if desired_type.__args__[-1] is ...: return desired_type.__args__[0].__class__.__name__ == "_ForwardRef" or check_type(obj[0], desired_type.__args__[0]) else: return len(obj) == len(desired_type.__args__) and all( desired_type.__args__[i].__class__.__name__ == "_ForwardRef" or check_type(obj[i], desired_type.__args__[i]) for i in range(len(obj))) else: return True return True # case of unsupported types: return True