# tf.ensure_shape

Last updated 2021-02-18 UTC.

[TensorFlow 1 version](/versions/r1.15/api_docs/python/tf/ensure_shape) · [View source on GitHub](https://github.com/tensorflow/tensorflow/blob/v2.4.0/tensorflow/python/ops/check_ops.py#L2219-L2338)

Updates the shape of a tensor and checks at runtime that the shape holds.

**Compat aliases for migration**

See the [Migration guide](https://www.tensorflow.org/guide/migrate) for more details.

[`tf.compat.v1.ensure_shape`](https://www.tensorflow.org/api_docs/python/tf/ensure_shape)

    tf.ensure_shape(
        x, shape, name=None
    )

#### For example:

    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def f(tensor):
      return tf.ensure_shape(tensor, [3, 3])

    f(tf.zeros([3, 3])) # Passes
    <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
    array([[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]], dtype=float32)>
    f([1, 2, 3]) # fails
    Traceback (most recent call last):

    InvalidArgumentError: Shape of tensor x [3] is not compatible with expected shape [3,3].

The above example raises [`tf.errors.InvalidArgumentError`](../tf/errors/InvalidArgumentError) because the shape (3,) is not compatible with the expected shape (3, 3).

With eager execution this is a shape assertion that returns the input:

    x = tf.constant([1,2,3])
    print(x.shape)
    (3,)
    x = tf.ensure_shape(x, [3])
    x = tf.ensure_shape(x, [5])
    Traceback (most recent call last):

    tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not
    compatible with expected shape [5]. [Op:EnsureShape]

Inside a [`tf.function`](../tf/function) or [`v1.Graph`](../tf/Graph) context, it checks both the buildtime and runtime shapes. This is stricter than [`tf.Tensor.set_shape`](../tf/Tensor#set_shape), which only checks the buildtime shape.

**Note:** This differs from [`tf.Tensor.set_shape`](../tf/Tensor#set_shape) in that it sets the static shape of the resulting tensor and enforces it at runtime, raising an error if the tensor's runtime shape is incompatible with the specified shape. [`tf.Tensor.set_shape`](../tf/Tensor#set_shape) sets the static shape of the tensor without enforcing it at runtime, which may result in inconsistencies between the statically-known shape of tensors and the runtime value of tensors.

For example, when loading images of a known size:

    @tf.function
    def decode_image(png):
      image = tf.image.decode_png(png, channels=3)
      # the `print` executes during tracing.
      print("Initial shape: ", image.shape)
      image = tf.ensure_shape(image,[28, 28, 3])
      print("Final shape: ", image.shape)
      return image

When tracing a function, no ops are executed, so shapes may be unknown. See the [Concrete Functions Guide](https://www.tensorflow.org/guide/concrete_function) for details.
    concrete_decode = decode_image.get_concrete_function(
        tf.TensorSpec([], dtype=tf.string))
    Initial shape:  (None, None, 3)
    Final shape:  (28, 28, 3)

    image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
    image = tf.cast(image,tf.uint8)
    png = tf.image.encode_png(image)
    image2 = concrete_decode(png)
    print(image2.shape)
    (28, 28, 3)

    image = tf.concat([image,image], axis=0)
    print(image.shape)
    (56, 28, 3)
    png = tf.image.encode_png(image)
    image2 = concrete_decode(png)
    Traceback (most recent call last):

    tf.errors.InvalidArgumentError: Shape of tensor DecodePng [56,28,3] is not
    compatible with expected shape [28,28,3].

**Caution:** If you don't use the result of [`tf.ensure_shape`](../tf/ensure_shape), the check may not run.

    @tf.function
    def bad_decode_image(png):
      image = tf.image.decode_png(png, channels=3)
      # the `print` executes during tracing.
      print("Initial shape: ", image.shape)
      # BAD: forgot to use the returned tensor.
      tf.ensure_shape(image,[28, 28, 3])
      print("Final shape: ", image.shape)
      return image

    image = bad_decode_image(png)
    Initial shape:  (None, None, 3)
    Final shape:  (None, None, 3)
    print(image.shape)
    (56, 28, 3)

| Args | Description |
|---------|--------------------------------------------------------------------------------------------------------|
| `x` | A `Tensor`. |
| `shape` | A `TensorShape` representing the shape of this tensor, a `TensorShapeProto`, a list, a tuple, or None. |
| `name` | A name for this operation (optional). Defaults to "EnsureShape". |

| Returns |
|---|
| A `Tensor`. Has the same type and contents as `x`. |

| Raises | Condition |
|---------------------------------------------------------------------------------------------------------------|---------------------------------------------------|
| [`tf.errors.InvalidArgumentError`](https://www.tensorflow.org/api_docs/python/tf/errors/InvalidArgumentError) | If `shape` is incompatible with the shape of `x`. |
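The Args table allows `shape` to be a list, a tuple, or `None`, and individual entries of a list may themselves be `None` to mark an unknown dimension (as in the `(None, None, 3)` shapes printed during tracing). The following is a minimal sketch, assuming TensorFlow 2.x with eager execution and a hypothetical helper named `check_rows_of_three`. It shows that a `None` entry accepts any size along that axis, that `shape=None` accepts a tensor of any rank, and how a caller might handle the `tf.errors.InvalidArgumentError` listed under Raises.

```python
import tensorflow as tf


def check_rows_of_three(x):
    """Hypothetical helper: any number of rows, but exactly 3 columns."""
    # `None` leaves the first dimension unconstrained; the second must be 3.
    return tf.ensure_shape(x, [None, 3])


# Any row count is accepted as long as there are 3 columns.
print(check_rows_of_three(tf.zeros([2, 3])).shape)  # (2, 3)
print(check_rows_of_three(tf.zeros([7, 3])).shape)  # (7, 3)

# `shape=None` means "unknown rank", so a tensor of any shape passes.
print(tf.ensure_shape(tf.zeros([4, 5, 6]), None).shape)  # (4, 5, 6)

# A mismatch in a constrained dimension is rejected at runtime.
try:
    check_rows_of_three(tf.zeros([2, 4]))
except tf.errors.InvalidArgumentError as err:
    print("Rejected:", type(err).__name__)  # Rejected: InvalidArgumentError
```

Using `None` only for dimensions that legitimately vary (such as a batch dimension) keeps the runtime check meaningful for the dimensions the code actually depends on.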
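The note comparing `tf.ensure_shape` with `tf.Tensor.set_shape` can be made concrete with a short sketch. It assumes TensorFlow 2.x with eager execution, and the function names `claim_with_set_shape` and `claim_with_ensure_shape` are illustrative only. Both functions are traced on an input of unknown rank and both advertise the static output shape `(3,)`, but only the `tf.ensure_shape` version checks that claim when the function actually runs.

```python
import tensorflow as tf

# Unknown rank at trace time, so the true shape is only known at runtime.
SPEC = tf.TensorSpec(shape=None, dtype=tf.float32)


@tf.function(input_signature=[SPEC])
def claim_with_set_shape(t):
    out = tf.identity(t)
    # Only updates the static (buildtime) shape; nothing checks it at runtime.
    out.set_shape([3])
    return out


@tf.function(input_signature=[SPEC])
def claim_with_ensure_shape(t):
    # Updates the static shape AND adds a runtime EnsureShape check.
    return tf.ensure_shape(t, [3])


# After tracing, both functions report the same static output shape.
print(claim_with_set_shape.get_concrete_function().structured_outputs.shape)  # (3,)
print(claim_with_ensure_shape.get_concrete_function().structured_outputs.shape)  # (3,)

wrong = tf.zeros([5])

# set_shape: the call is not rejected, and the result keeps its real
# runtime shape, which contradicts the static claim made while tracing.
print(claim_with_set_shape(wrong).shape)

# ensure_shape: the same input is rejected at runtime.
try:
    claim_with_ensure_shape(wrong)
except tf.errors.InvalidArgumentError as err:
    print("Rejected:", type(err).__name__)
```

The `set_shape` variant is the kind of silent inconsistency the note warns about: downstream code that trusts the static shape `(3,)` would be working with a 5-element tensor.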