tf.keras.layers.AbstractRNNCell
Abstract object representing an RNN cell.
Inherits From: Layer
tf.keras.layers.AbstractRNNCell(
trainable=True, name=None, dtype=None, dynamic=False, **kwargs
)
This is the base class for implementing RNN cells with custom behavior.
Every RNNCell must have the output_size and state_size properties described under Attributes below, and must implement call with the signature (output, next_state) = call(input, state).
Example:
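from tensorflow.keras import backend as K
from tensorflow.keras.layers import AbstractRNNCell

class MinimalRNNCell(AbstractRNNCell):

    def __init__(self, units, **kwargs):
        super(MinimalRNNCell, self).__init__(**kwargs)
        self.units = units

    @property
    def state_size(self):
        return self.units

    @property
    def output_size(self):
        # Also required: AbstractRNNCell leaves output_size abstract.
        return self.units

    def build(self, input_shape):
        # Input-to-hidden weights.
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                      initializer='uniform',
                                      name='kernel')
        # Hidden-to-hidden (recurrent) weights.
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform',
            name='recurrent_kernel')
        self.built = True

    def call(self, inputs, states):
        # states is a list; its first entry is the previous output.
        prev_output = states[0]
        h = K.dot(inputs, self.kernel)
        output = h + K.dot(prev_output, self.recurrent_kernel)
        # The new output doubles as the next state.
        return output, output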
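To see the (output, next_state) contract at work without Keras, the sketch below ports the MinimalRNNCell arithmetic to NumPy and drives it over a sequence with a hand-written loop. MinimalRNNCellNumPy and run_rnn are hypothetical names; the loop only approximates what a recurrent wrapper such as tf.keras.layers.RNN does (it ignores masking, unrolling and return options), and this port returns its state as a one-element list so that states[0] works on every step.

```python
import numpy as np


class MinimalRNNCellNumPy:
    """Hypothetical NumPy port of the MinimalRNNCell math (no Keras)."""

    def __init__(self, input_dim, units, seed=0):
        rng = np.random.default_rng(seed)
        self.units = units
        # Same weight shapes as the Keras example, filled with small random values.
        self.kernel = rng.uniform(-0.05, 0.05, size=(input_dim, units))
        self.recurrent_kernel = rng.uniform(-0.05, 0.05, size=(units, units))

    @property
    def state_size(self):
        return self.units

    @property
    def output_size(self):
        return self.units

    def call(self, inputs, states):
        prev_output = states[0]
        output = inputs @ self.kernel + prev_output @ self.recurrent_kernel
        # (output, next_state): the next state is a list holding the output.
        return output, [output]


def run_rnn(cell, sequence):
    """Feed a (batch, time, features) array to the cell one step at a time."""
    batch = sequence.shape[0]
    states = [np.zeros((batch, cell.state_size))]  # zero initial state
    outputs = []
    for t in range(sequence.shape[1]):
        output, states = cell.call(sequence[:, t, :], states)
        outputs.append(output)
    return np.stack(outputs, axis=1), states


x = np.ones((2, 5, 3))  # batch=2, time=5, features=3
cell = MinimalRNNCellNumPy(input_dim=3, units=4)
seq_out, final_states = run_rnn(cell, x)
print(seq_out.shape)  # (2, 5, 4)
```

Because the initial state is zero, the first output reduces to inputs @ kernel, and the final state is simply the last output.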
This definition of cell differs from the definition used in the literature.
In the literature, 'cell' refers to an object with a single scalar output.
This definition refers to a horizontal array of such units.
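The "horizontal array" view can be checked numerically. The sketch below reuses the MinimalRNNCell update with random matrices (all names are illustrative) and shows that column j of the cell's output equals what a single scalar unit computes from column j of each weight matrix; note that each unit still reads the full previous output vector.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(2, 3))       # batch of 2 inputs with 3 features
h_prev = rng.normal(size=(2, 4))  # previous output of a 4-unit cell
W = rng.normal(size=(3, 4))       # input kernel, one column per unit
U = rng.normal(size=(4, 4))       # recurrent kernel, one column per unit

# The whole cell: an output matrix with 4 columns.
layer_out = x @ W + h_prev @ U

# Unit j on its own: one scalar output per example, built from column j.
unit_outs = [x @ W[:, j] + h_prev @ U[:, j] for j in range(4)]

# Placing the scalar units side by side reproduces the cell's output.
stacked = np.stack(unit_outs, axis=1)
print(np.allclose(layer_out, stacked))  # True
```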
An RNN cell, in the most abstract setting, is anything that has
a state and performs some operation that takes a matrix of inputs.
This operation results in an output matrix with self.output_size columns.
If self.state_size is an integer, the operation also results in a new
state matrix with self.state_size columns. If self.state_size is a
(possibly nested tuple of) TensorShape object(s), then the operation should
return a matching structure of Tensors, each with shape [batch_size].concatenate(s)
for the corresponding s in self.state_size.
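The shape rule above can be sketched in plain Python. Shape is a hypothetical stand-in for TensorShape, used only so that a per-example shape can be told apart from a tuple of several states; state_shapes is likewise a hypothetical helper, not part of the Keras API.

```python
class Shape(tuple):
    """Hypothetical stand-in for tf.TensorShape: one state's per-example shape."""


def state_shapes(state_size, batch_size):
    """Map a (possibly nested) state_size to the expected state shapes."""
    if isinstance(state_size, int):
        # Integer: a state matrix with state_size columns.
        return (batch_size, state_size)
    if isinstance(state_size, Shape):
        # Shape s: [batch_size].concatenate(s).
        return (batch_size,) + tuple(state_size)
    # A (possibly nested) tuple or list of states: recurse element-wise.
    return type(state_size)(state_shapes(s, batch_size) for s in state_size)


print(state_shapes(8, batch_size=2))       # (2, 8)
print(state_shapes((8, 8), batch_size=2))  # ((2, 8), (2, 8))
print(state_shapes((Shape((3, 3)), (4, Shape((5,)))), batch_size=2))
# ((2, 3, 3), ((2, 4), (2, 5)))
```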
Attributes
output_size: Integer or TensorShape; the size of the outputs produced by this cell.
state_size: Size(s) of the state(s) used by this cell. It can be an integer, a TensorShape, or a tuple of integers or TensorShapes.
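output_size and state_size need not have the same form. The NumPy sketch below assumes an LSTM-style cell (the class name TinyLSTMCell is hypothetical and it is not the Keras LSTMCell) whose state_size is a tuple of two integers, one for the hidden state h and one for the carry c, while output_size is a single integer.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


class TinyLSTMCell:
    """Hypothetical NumPy LSTM-style cell: one output, two state tensors."""

    def __init__(self, input_dim, units, seed=0):
        rng = np.random.default_rng(seed)
        self.units = units
        # Four gate blocks (input, forget, candidate, output) side by side.
        self.kernel = rng.normal(scale=0.1, size=(input_dim, 4 * units))
        self.recurrent_kernel = rng.normal(scale=0.1, size=(units, 4 * units))
        self.bias = np.zeros(4 * units)

    @property
    def state_size(self):
        # Two states, hidden h and carry c, each with `units` columns.
        return (self.units, self.units)

    @property
    def output_size(self):
        return self.units

    def call(self, inputs, states):
        h_prev, c_prev = states
        z = inputs @ self.kernel + h_prev @ self.recurrent_kernel + self.bias
        i, f, g, o = np.split(z, 4, axis=-1)
        c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        # The output is h; the next state is the pair (h, c).
        return h, (h, c)


cell = TinyLSTMCell(input_dim=3, units=5)
batch = 2
states = tuple(np.zeros((batch, n)) for n in cell.state_size)
out, (h, c) = cell.call(np.ones((batch, 3)), states)
print(out.shape, h.shape, c.shape)  # (2, 5) (2, 5) (2, 5)
```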
Methods
get_initial_state
get_initial_state(
inputs=None, batch_size=None, dtype=None
)
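The page lists only the signature. As far as we can tell from the TensorFlow 2.x source, the default implementation returns zero-filled state shaped by state_size, taking the batch size and dtype from inputs when inputs is given. The NumPy sketch below mimics that behaviour under those assumptions; zero_initial_state is a hypothetical name, and passing state_size explicitly is a simplification (the real method reads it from the cell).

```python
import numpy as np


def zero_initial_state(state_size, inputs=None, batch_size=None, dtype=None):
    """Hypothetical NumPy analogue of a zero-filled get_initial_state."""
    if inputs is not None:
        # When inputs are given, batch size and dtype come from them.
        batch_size = inputs.shape[0]
        dtype = inputs.dtype
    if batch_size is None or dtype is None:
        raise ValueError("need inputs, or both batch_size and dtype")

    def zeros_for(size):
        if isinstance(size, int):
            return np.zeros((batch_size, size), dtype=dtype)
        # Nested state sizes map to nested lists of zero arrays.
        return [zeros_for(s) for s in size]

    return zeros_for(state_size)


x = np.ones((3, 7), dtype=np.float32)
init = zero_initial_state((4, 4), inputs=x)
print([s.shape for s in init])  # [(3, 4), (3, 4)]
print(init[0].dtype)            # float32
```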
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2020-10-01 UTC.