From f4dcf688f041c2a383df2072a59a399e54cb11ae Mon Sep 17 00:00:00 2001 From: Fabio Manganiello Date: Mon, 23 Mar 2020 01:10:59 +0100 Subject: [PATCH] Set default values for metrics for regression and networks --- platypush/plugins/tensorflow/__init__.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/platypush/plugins/tensorflow/__init__.py b/platypush/plugins/tensorflow/__init__.py index fb82d8a1e..83e3a8fe3 100644 --- a/platypush/plugins/tensorflow/__init__.py +++ b/platypush/plugins/tensorflow/__init__.py @@ -238,7 +238,7 @@ class TensorflowPlugin(Plugin): you could also pass a dictionary, such as ``metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}``. You can also pass a list ``(len = len(outputs))`` of lists of metrics such as ``metrics=[['accuracy'], ['accuracy', 'mse']]`` or - ``metrics=['accuracy', ['accuracy', 'mse']]``. + ``metrics=['accuracy', ['accuracy', 'mse']]``. Default: ``['accuracy']``. :param loss_weights: Optional list or dictionary specifying scalar coefficients (Python floats) to weight the loss contributions of different model outputs. The loss value that will be minimized by the model @@ -370,6 +370,9 @@ class TensorflowPlugin(Plugin): layer = self._layer_from_dict(layer.pop('type'), **layer) model.add(layer) + if not metrics: + metrics = ['accuracy'] + model.compile( optimizer=optimizer, loss=loss, @@ -433,7 +436,7 @@ class TensorflowPlugin(Plugin): you could also pass a dictionary, such as ``metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}``. You can also pass a list ``(len = len(outputs))`` of lists of metrics such as ``metrics=[['accuracy'], ['accuracy', 'mse']]`` or - ``metrics=['accuracy', ['accuracy', 'mse']]``. + ``metrics=['accuracy', ['accuracy', 'mse']]``. Default: ``['mae', 'mse']``. :param loss_weights: Optional list or dictionary specifying scalar coefficients (Python floats) to weight the loss contributions of different model outputs. The loss value that will be minimized by the model @@ -501,6 +504,9 @@ class TensorflowPlugin(Plugin): else: model.output_labels = [] + if not metrics: + metrics = ['mae', 'mse'] + model.compile( optimizer=optimizer, loss=loss,