Warning: This project is deprecated. TensorFlow Addons has stopped development, and the project will only provide minimal maintenance releases until May 2024. See the full announcement on GitHub.
TensorFlow Addons Layers: WeightNormalization
Overview
This notebook will demonstrate how to use the Weight Normalization layer and how it can improve convergence.
WeightNormalization
A Simple Reparameterization to Accelerate Training of Deep Neural Networks:
Tim Salimans, Diederik P. Kingma (2016)
By reparameterizing the weights in this way you improve the conditioning of the optimization problem and speed up convergence of stochastic gradient descent. Our reparameterization is inspired by batch normalization but does not introduce any dependencies between the examples in a minibatch. This means that our method can also be applied successfully to recurrent models such as LSTMs and to noise-sensitive applications such as deep reinforcement learning or generative models, for which batch normalization is less well suited. Although our method is much simpler, it still provides much of the speed-up of full batch normalization. In addition, the computational overhead of our method is lower, permitting more optimization steps to be taken in the same amount of time.
https://arxiv.org/abs/1602.07868
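Concretely, the paper reparameterizes each weight vector as w = g * v / ||v||: the learned scalar g carries the norm of w, while v only determines its direction, and the two are optimized independently. The following is a minimal NumPy sketch of the idea (illustrative only; the shapes are arbitrary and this is not the tfa.layers.WeightNormalization implementation):

import numpy as np

# Illustrative sketch of the reparameterization w = g * v / ||v|| for one dense
# weight matrix, following the paper's notation. Not the tfa implementation.
rng = np.random.default_rng(0)
v = rng.normal(size=(784, 120))   # unnormalized direction parameters (inputs x units)
g = np.ones(120)                  # one learned scale per output unit

# Normalize each column (the weight vector of one output unit) and rescale by g.
w = g * v / np.linalg.norm(v, axis=0, keepdims=True)

# Every effective weight vector now has norm exactly g.
print(np.allclose(np.linalg.norm(w, axis=0), g))  # True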

Setup
pip install -U tensorflow-addons
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
from matplotlib import pyplot as plt
# Hyperparameters
batch_size = 32
epochs = 10
num_classes = 10
Build Models
# Standard ConvNet
reg_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(6, 5, activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(16, 5, activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(120, activation='relu'),
    tf.keras.layers.Dense(84, activation='relu'),
    tf.keras.layers.Dense(num_classes, activation='softmax'),
])
# WeightNorm ConvNet
wn_model = tf.keras.Sequential([
    tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(6, 5, activation='relu')),
    tf.keras.layers.MaxPooling2D(2, 2),
    tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(16, 5, activation='relu')),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(120, activation='relu')),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(84, activation='relu')),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(num_classes, activation='softmax')),
])
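WeightNormalization is a layer wrapper, so the two models differ only in the wrapping. As a quick check, you can build both models on the CIFAR-10 input shape and compare their summaries; the wrapped model should report a small number of extra parameters (roughly one scale g per output unit of each wrapped layer). A minimal sketch, assuming 32x32x3 inputs:

# Create the weights without training so the two architectures can be compared.
reg_model.build(input_shape=(None, 32, 32, 3))
wn_model.build(input_shape=(None, 32, 32, 3))

reg_model.summary()
wn_model.summary()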
Load Data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# Convert class vectors to binary class matrices.
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 2s 0us/step
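After this preprocessing the inputs are 32x32 RGB images scaled to [0, 1] and the labels are one-hot vectors. An optional sanity check:

# Optional sanity check of the preprocessed arrays.
print(x_train.shape, x_train.dtype)   # (50000, 32, 32, 3) float32
print(y_train.shape)                  # (50000, 10) one-hot labels
print(x_train.min(), x_train.max())   # 0.0 1.0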
Train Models
reg_model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

reg_history = reg_model.fit(x_train, y_train,
                            batch_size=batch_size,
                            epochs=epochs,
                            validation_data=(x_test, y_test),
                            shuffle=True)
Epoch 1/10
1563/1563 [==============================] - 7s 4ms/step - loss: 1.6086 - accuracy: 0.4134 - val_loss: 1.3833 - val_accuracy: 0.4965
Epoch 2/10
1563/1563 [==============================] - 5s 3ms/step - loss: 1.3170 - accuracy: 0.5296 - val_loss: 1.2546 - val_accuracy: 0.5553
Epoch 3/10
1563/1563 [==============================] - 5s 3ms/step - loss: 1.1944 - accuracy: 0.5776 - val_loss: 1.1566 - val_accuracy: 0.5922
Epoch 4/10
1563/1563 [==============================] - 5s 3ms/step - loss: 1.1192 - accuracy: 0.6033 - val_loss: 1.1554 - val_accuracy: 0.5877
Epoch 5/10
1563/1563 [==============================] - 5s 3ms/step - loss: 1.0576 - accuracy: 0.6243 - val_loss: 1.1264 - val_accuracy: 0.6028
Epoch 6/10
1563/1563 [==============================] - 5s 3ms/step - loss: 1.0041 - accuracy: 0.6441 - val_loss: 1.1555 - val_accuracy: 0.5989
Epoch 7/10
1563/1563 [==============================] - 5s 3ms/step - loss: 0.9626 - accuracy: 0.6605 - val_loss: 1.1076 - val_accuracy: 0.6145
Epoch 8/10
1563/1563 [==============================] - 5s 3ms/step - loss: 0.9207 - accuracy: 0.6734 - val_loss: 1.1362 - val_accuracy: 0.6128
Epoch 9/10
1563/1563 [==============================] - 5s 3ms/step - loss: 0.8836 - accuracy: 0.6872 - val_loss: 1.1191 - val_accuracy: 0.6216
Epoch 10/10
1563/1563 [==============================] - 5s 3ms/step - loss: 0.8559 - accuracy: 0.6963 - val_loss: 1.0973 - val_accuracy: 0.6239
wn_model.compile(optimizer='adam',
                 loss='categorical_crossentropy',
                 metrics=['accuracy'])

wn_history = wn_model.fit(x_train, y_train,
                          batch_size=batch_size,
                          epochs=epochs,
                          validation_data=(x_test, y_test),
                          shuffle=True)
Epoch 1/10
1563/1563 [==============================] - 10s 5ms/step - loss: 1.6049 - accuracy: 0.4162 - val_loss: 1.3947 - val_accuracy: 0.4967
Epoch 2/10
1563/1563 [==============================] - 8s 5ms/step - loss: 1.3372 - accuracy: 0.5218 - val_loss: 1.2908 - val_accuracy: 0.5305
Epoch 3/10
1563/1563 [==============================] - 8s 5ms/step - loss: 1.2231 - accuracy: 0.5647 - val_loss: 1.2343 - val_accuracy: 0.5590
Epoch 4/10
1563/1563 [==============================] - 8s 5ms/step - loss: 1.1362 - accuracy: 0.5980 - val_loss: 1.1979 - val_accuracy: 0.5782
Epoch 5/10
1563/1563 [==============================] - 7s 5ms/step - loss: 1.0799 - accuracy: 0.6159 - val_loss: 1.1784 - val_accuracy: 0.5860
Epoch 6/10
1563/1563 [==============================] - 8s 5ms/step - loss: 1.0261 - accuracy: 0.6360 - val_loss: 1.1850 - val_accuracy: 0.5818
Epoch 7/10
1563/1563 [==============================] - 8s 5ms/step - loss: 0.9798 - accuracy: 0.6538 - val_loss: 1.1201 - val_accuracy: 0.6082
Epoch 8/10
1563/1563 [==============================] - 8s 5ms/step - loss: 0.9349 - accuracy: 0.6701 - val_loss: 1.1204 - val_accuracy: 0.6136
Epoch 9/10
1563/1563 [==============================] - 8s 5ms/step - loss: 0.8955 - accuracy: 0.6829 - val_loss: 1.1292 - val_accuracy: 0.6103
Epoch 10/10
1563/1563 [==============================] - 8s 5ms/step - loss: 0.8555 - accuracy: 0.6995 - val_loss: 1.1445 - val_accuracy: 0.6103
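Before plotting the training curves, you can also compare the final test performance of the two trained models directly; a minimal sketch using the models trained above:

# Evaluate both trained models on the held-out test set.
reg_loss, reg_acc = reg_model.evaluate(x_test, y_test, verbose=0)
wn_loss, wn_acc = wn_model.evaluate(x_test, y_test, verbose=0)

print(f'Regular ConvNet    test accuracy: {reg_acc:.4f}')
print(f'WeightNorm ConvNet test accuracy: {wn_acc:.4f}')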
reg_accuracy = reg_history.history['accuracy']
wn_accuracy = wn_history.history['accuracy']
plt.plot(np.linspace(0, epochs, epochs), reg_accuracy,
         color='red', label='Regular ConvNet')
plt.plot(np.linspace(0, epochs, epochs), wn_accuracy,
         color='blue', label='WeightNorm ConvNet')
plt.title('WeightNorm Accuracy Comparison')
plt.legend()
plt.grid(True)
plt.show()
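The same comparison can be made on validation accuracy, which is also recorded in the two fit histories; a minimal follow-up sketch:

# Plot validation accuracy from the same training histories.
reg_val_accuracy = reg_history.history['val_accuracy']
wn_val_accuracy = wn_history.history['val_accuracy']

plt.plot(np.linspace(0, epochs, epochs), reg_val_accuracy,
         color='red', label='Regular ConvNet')
plt.plot(np.linspace(0, epochs, epochs), wn_val_accuracy,
         color='blue', label='WeightNorm ConvNet')

plt.title('WeightNorm Validation Accuracy Comparison')
plt.legend()
plt.grid(True)
plt.show()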
