@tf.function deftest(a, b): print('trace', a, b) return a, b
1 2 3 4 5
for batch in range(1, 10): test(np.zeros([batch, 2], dtype=np.float32), np.zeros([batch, 3], dtype=np.float32)) None2None3
trace Tensor("a:0", shape=(1, 2), dtype=float32) Tensor("b:0", shape=(1, 3), dtype=float32)
trace Tensor("a:0", shape=(2, 2), dtype=float32) Tensor("b:0", shape=(2, 3), dtype=float32)
trace Tensor("a:0", shape=(3, 2), dtype=float32) Tensor("b:0", shape=(3, 3), dtype=float32)
trace Tensor("a:0", shape=(4, 2), dtype=float32) Tensor("b:0", shape=(4, 3), dtype=float32)
trace Tensor("a:0", shape=(5, 2), dtype=float32) Tensor("b:0", shape=(5, 3), dtype=float32)
WARNING:tensorflow:5 out of the last 5 calls to <function test at 0x000002514EB061F8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
trace Tensor("a:0", shape=(6, 2), dtype=float32) Tensor("b:0", shape=(6, 3), dtype=float32)
WARNING:tensorflow:6 out of the last 6 calls to <function test at 0x000002514EB061F8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
trace Tensor("a:0", shape=(7, 2), dtype=float32) Tensor("b:0", shape=(7, 3), dtype=float32)
WARNING:tensorflow:7 out of the last 7 calls to <function test at 0x000002514EB061F8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
trace Tensor("a:0", shape=(8, 2), dtype=float32) Tensor("b:0", shape=(8, 3), dtype=float32)
WARNING:tensorflow:8 out of the last 8 calls to <function test at 0x000002514EB061F8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
trace Tensor("a:0", shape=(9, 2), dtype=float32) Tensor("b:0", shape=(9, 3), dtype=float32)
WARNING:tensorflow:9 out of the last 9 calls to <function test at 0x000002514EB061F8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
使用 input_signature
1 2 3 4 5
@tf.function(input_signature=(tf.TensorSpec(shape=[None, 2], dtype=tf.float32), tf.TensorSpec(shape=[None, 3], dtype=tf.float32))) deftest(a, b): print('trace', a, b) return a, b
1 2
for batch in range(10): test(np.zeros([batch, 2], dtype=np.float32), np.zeros([batch, 3], dtype=np.float32))
@tf.function(input_signature=(tf.TensorSpec(shape=[None, 2], dtype=tf.float32), tf.TensorSpec(shape=[None, 3], dtype=tf.float32))) deftest(a, b=None): print(a, b) return a, b
1 2
for batch in range(10): test(np.zeros([batch, 2], dtype=np.float32), np.zeros([batch, 3], dtype=np.float32))
for batch in range(10): test(np.zeros([batch, 2], dtype=np.float32))
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
~\Miniconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\function.py in _convert_inputs_to_signature(inputs, input_signature, flat_input_signature)
2239 flatten_inputs[index] = ops.convert_to_tensor(
-> 2240 value, dtype_hint=spec.dtype)
2241 need_packing = True
~\Miniconda3\envs\tf\lib\site-packages\tensorflow_core\python\framework\ops.py in convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, dtype_hint, ctx, accepted_result_types)
1313 if ret is None:
-> 1314 ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
1315
~\Miniconda3\envs\tf\lib\site-packages\tensorflow_core\python\framework\constant_op.py in _constant_tensor_conversion_function(v, dtype, name, as_ref)
316 _ = as_ref
--> 317 return constant(v, dtype=dtype, name=name)
318
~\Miniconda3\envs\tf\lib\site-packages\tensorflow_core\python\framework\constant_op.py in constant(value, dtype, shape, name)
257 return _constant_impl(value, dtype, shape, name, verify_shape=False,
--> 258 allow_broadcast=True)
259
~\Miniconda3\envs\tf\lib\site-packages\tensorflow_core\python\framework\constant_op.py in _constant_impl(value, dtype, shape, name, verify_shape, allow_broadcast)
265 if ctx.executing_eagerly():
--> 266 t = convert_to_eager_tensor(value, ctx, dtype)
267 if shape is None:
~\Miniconda3\envs\tf\lib\site-packages\tensorflow_core\python\framework\constant_op.py in convert_to_eager_tensor(value, ctx, dtype)
95 ctx.ensure_initialized()
---> 96 return ops.EagerTensor(value, ctx.device_name, dtype)
97
ValueError: Attempt to convert a value (None) with an unsupported type (<class 'NoneType'>) to a Tensor.
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
<ipython-input-14-ea9826667fb9> in <module>
1 for batch in range(10):
----> 2 test(np.zeros([batch, 2], dtype=np.float32), None)
~\Miniconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\def_function.py in __call__(self, *args, **kwds)
566 xla_context.Exit()
567 else:
--> 568 result = self._call(*args, **kwds)
569
570 if tracing_count == self._get_tracing_count():
~\Miniconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\def_function.py in _call(self, *args, **kwds)
604 # In this case we have not created variables on the first call. So we can
605 # run the first trace but we should fail if variables are created.
--> 606 results = self._stateful_fn(*args, **kwds)
607 if self._created_variables:
608 raise ValueError("Creating variables on a non-first call to a function"
~\Miniconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\function.py in __call__(self, *args, **kwargs)
2360 """Calls a graph function specialized to the inputs."""
2361 with self._lock:
-> 2362 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
2363 return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
2364
~\Miniconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\function.py in _maybe_define_function(self, args, kwargs)
2659 if self.input_signature is None or args is not None or kwargs is not None:
2660 args, kwargs = self._function_spec.canonicalize_function_inputs(
-> 2661 *args, **kwargs)
2662
2663 cache_key = self._cache_key(args, kwargs)
~\Miniconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\function.py in canonicalize_function_inputs(self, *args, **kwargs)
2183 inputs,
2184 self._input_signature,
-> 2185 self._flat_input_signature)
2186 return inputs, {}
2187
~\Miniconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\function.py in _convert_inputs_to_signature(inputs, input_signature, flat_input_signature)
2244 "the Python function must be convertible to "
2245 "tensors:\n%s" %
-> 2246 format_error_message(inputs, input_signature))
2247
2248 if any(not spec.is_compatible_with(other) for spec, other in zip(
ValueError: When input_signature is provided, all inputs to the Python function must be convertible to tensors:
inputs: (
[],
None)
input_signature: (
TensorSpec(shape=(None, 2), dtype=tf.float32, name=None),
TensorSpec(shape=(None, 3), dtype=tf.float32, name=None))
1 2 3 4
@tf.function(input_signature=(tf.TensorSpec(shape=[None, 2], dtype=tf.float32), None)) deftest(a, b=None): print(a, b) return a, b
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-18-c5e14b2bdf68> in <module>
----> 1 @tf.function(input_signature=(tf.TensorSpec(shape=[None, 2], dtype=tf.float32), None))
2 def test(a, b=None):
3 print(a, b)
4 return a, b
~\Miniconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\def_function.py in function(func, input_signature, autograph, experimental_implements, experimental_autograph_options, experimental_relax_shapes, experimental_compile)
1172 """
1173 if input_signature is not None:
-> 1174 function_lib.validate_signature(input_signature)
1175
1176 def decorated(inner_function):
~\Miniconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\function.py in validate_signature(signature)
2738 raise TypeError("Invalid input_signature {}; input_signature must be "
2739 "a possibly nested sequence of TensorSpec objects."
-> 2740 .format(signature))
2741
2742
TypeError: Invalid input_signature (TensorSpec(shape=(None, 2), dtype=tf.float32, name=None), None); input_signature must be a possibly nested sequence of TensorSpec objects.
1 2 3 4 5 6 7 8
# 可以不传参数为None的input_signature @tf.function(input_signature=(tf.TensorSpec(shape=[None, 2], dtype=tf.float32),)) deftest(a, b=None): print(a, b) return a, b
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-26-b9e2079756e4> in <module>
1 # 但不能传None
----> 2 test_c(tf.zeros([batch, 2], dtype=tf.float32), None)
~\Miniconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\function.py in __call__(self, *args, **kwargs)
1549 TypeError: For invalid positional/keyword argument combinations.
1550 """
-> 1551 return self._call_impl(args, kwargs)
1552
1553 def _call_impl(self, args, kwargs, cancellation_manager=None):
~\Miniconda3\envs\tf\lib\site-packages\tensorflow_core\python\eager\function.py in _call_impl(self, args, kwargs, cancellation_manager)
1568 "of {}), got {}. When calling a concrete function, positional "
1569 "arguments may not be bound to Tensors within nested structures."
-> 1570 ).format(self._num_positional_args, self._arg_keywords, args))
1571 args = list(args)
1572 for keyword in self._arg_keywords[len(args):]:
TypeError: Expected at most 1 positional arguments (and the rest keywords, of ['a']), got (<tf.Tensor: shape=(16, 2), dtype=float32, numpy=
array([[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]], dtype=float32)>, None). When calling a concrete function, positional arguments may not be bound to Tensors within nested structures.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
# 将numpy类型转为Tensor,同时去除None参数 defnp_to_tensor(fn): defc(*args, **kwargs): return fn(*[tf.constant(k) ifnot isinstance(k, tf.Tensor) else k for k in args if k isnotNone], **{k: tf.constant(v) ifnot isinstance(v, tf.Tensor) else v for k, v in kwargs.items() if v isnotNone})
return c
@tf.function deftest(a, b=None): print(a, b) return a, b
test_c = test.get_concrete_function(a=tf.TensorSpec(shape=[None, 2], dtype=tf.float32)) test = np_to_tensor(test_c)