diff --git a/platypush/plugins/tensorflow/__init__.py b/platypush/plugins/tensorflow/__init__.py index 83e3a8fe..55d39df7 100644 --- a/platypush/plugins/tensorflow/__init__.py +++ b/platypush/plugins/tensorflow/__init__.py @@ -55,7 +55,7 @@ class TensorflowPlugin(Plugin): _csv_extensions = ['csv', 'tsv'] _supported_data_file_extensions = [*_csv_extensions, *_numpy_extensions, *_image_extensions] - def __init__(self, workdir: str = os.path.join(Config.get('workdir'), 'tensorflow'), **kwargs): + def __init__(self, workdir: Optional[str] = None, **kwargs): """ :param workdir: Working directory for TensorFlow, where models will be stored (default: PLATYPUSH_WORKDIR/tensorflow). @@ -63,7 +63,9 @@ class TensorflowPlugin(Plugin): super().__init__(**kwargs) self.models: Dict[str, Model] = {} self._model_locks: Dict[str, threading.RLock()] = {} - self._work_dir = os.path.abspath(os.path.expanduser(workdir)) + self._work_dir = os.path.abspath(os.path.expanduser(workdir)) if workdir else \ + os.path.join(Config.get('workdir'), 'tensorflow') + self._models_dir = os.path.join(self._work_dir, 'models') os.makedirs(self._models_dir, mode=0o755, exist_ok=True)