Overview
Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, harnessing both labeled and unlabeled data.
In this tutorial, we will explore the use of graph regularization to classify documents that form a natural (organic) graph.
The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows:
- Generate training data from the input graph and sample features. Nodes in the graph correspond to samples, and edges in the graph correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
- Create a neural network as a base model using the Keras sequential, functional, or subclass API.
- Wrap the base model with the GraphRegularization wrapper class, which is provided by the NSL framework, to create a new graph Keras model. This new model will include a graph regularization loss as the regularization term in its training objective.
- Train and evaluate the graph Keras model.
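To make the idea of a graph-regularized objective concrete, here is a minimal sketch in plain Python. It is not NSL's implementation: the helper names, the squared L2 distance, and the toy numbers are assumptions chosen only for illustration.

```python
def squared_l2(u, v):
    """Squared L2 distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def graph_regularized_loss(supervised_loss, sample_output, neighbor_outputs,
                           neighbor_weights, multiplier):
    """Supervised loss plus a weighted neighbor-distance penalty.

    Hypothetical helper for illustration only.
    """
    graph_loss = sum(
        weight * squared_l2(sample_output, nbr_output)
        for nbr_output, weight in zip(neighbor_outputs, neighbor_weights))
    return supervised_loss + multiplier * graph_loss


# Toy example: one real neighbor (weight 1.0) and one missing neighbor
# (weight 0.0), which therefore contributes nothing to the penalty.
total = graph_regularized_loss(
    supervised_loss=0.5,
    sample_output=[1.0, 0.0],
    neighbor_outputs=[[0.0, 0.0], [9.0, 9.0]],
    neighbor_weights=[1.0, 0.0],
    multiplier=0.1)
print(total)
```

The penalty pulls the model's output for a sample toward the outputs for its graph neighbors, and the multiplier sets how strongly that pull competes with the supervised loss.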
Setup
Install the Neural Structured Learning package.
pip install --quiet neural-structured-learning
Dependencies and imports
import neural_structured_learning as nsl
import tensorflow as tf
# Resets notebook state
tf.keras.backend.clear_session()
print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
"GPU is",
"available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
Version: 2.8.0-rc0
Eager mode: True
GPU is NOT AVAILABLE
2022-01-05 12:39:27.704660: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Cora dataset
The Cora dataset is a citation graph where nodes represent machine learning papers and edges represent citations between pairs of papers. The task is document classification, where the goal is to assign each paper to one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.
Graph
The original graph is directed. However, for the purpose of this example, we consider the undirected version of this graph. So, if paper A cites paper B, we also consider paper B to have cited paper A. Although this is not necessarily true, in this example we treat citations as a proxy for similarity, which is usually a commutative property.
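Treating the citation graph as undirected amounts to adding the reverse of every edge. A minimal sketch, using a hypothetical helper and made-up paper IDs rather than the actual Cora files:

```python
def make_undirected(directed_edges):
    """Returns a sorted list containing both directions of every edge."""
    undirected = set()
    for source, target in directed_edges:
        undirected.add((source, target))
        undirected.add((target, source))
    return sorted(undirected)


# Paper A cites B, and B cites C.
citations = [("A", "B"), ("B", "C")]
edges = make_undirected(citations)
print(edges)
```

Using a set also removes duplicates in case two papers already cite each other.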
Features
Each paper in the input effectively contains 2 features:
Words: A dense, multi-hot bag-of-words representation of the text in the paper. The vocabulary for the Cora dataset contains 1433 unique words. So, the length of this feature is 1433, and the value at position 'i' is 0/1, indicating whether word 'i' in the vocabulary exists in the given paper or not.
Label: A single integer representing the class ID (category) of the paper.
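As an illustration of the multi-hot bag-of-words encoding, the sketch below uses a toy 4-word vocabulary. The real Cora vocabulary has 1433 words; the words and the helper name here are invented for the example.

```python
def multi_hot(vocabulary, document_words):
    """Returns a 0/1 vector with one position per vocabulary word."""
    present = set(document_words)
    return [1 if word in present else 0 for word in vocabulary]


vocabulary = ["graph", "neural", "kernel", "bayes"]
paper_words = ["neural", "graph", "neural"]  # repeated words still map to 1
words_feature = multi_hot(vocabulary, paper_words)
print(words_feature)
```

The encoding records only whether a word occurs, not how often, which is why the repeated word above does not change the result.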
Download the Cora dataset
wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz
cora/
cora/README
cora/cora.cites
cora/cora.content
Convert the Cora data to the NSL format
To preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we will run the 'preprocess_cora_dataset.py' script, which is included in the NSL GitHub repository. This script does the following:
- Generate neighbor features using the original node features and the graph.
- Generate train and test data splits containing tf.train.Example instances.
- Persist the resulting train and test data in the TFRecord format.
!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2022-01-05 12:39:28-- https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11640 (11K) [text/plain]
Saving to: ‘preprocess_cora_dataset.py’
preprocess_cora_dat 100%[===================>] 11.37K --.-KB/s in 0s
2022-01-05 12:39:28 (78.9 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640]
2022-01-05 12:39:31.378912: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Reading graph file: /tmp/cora/cora.cites...
Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds).
Making all edges bi-directional...
Done (0.01 seconds). Total graph nodes: 2708
Joining seed and neighbor tf.train.Examples with graph edges...
Done creating and writing 2155 merged tf.train.Examples (1.36 seconds).
Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)]
Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr.
Output test data written to TFRecord file: /tmp/cora/test_examples.tfr.
Total running time: 0.04 minutes.
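The sketch below illustrates, in plain Python, what joining a seed example with its neighbors could look like. It is not the script's actual code: the helper, the toy features, and the edge weight of 1.0 are assumptions. Only the key naming ('NL_nbr_<i>_words', 'NL_nbr_<i>_weight') follows the feature names used elsewhere in this notebook.

```python
NBR_FEATURE_PREFIX = "NL_nbr_"
NBR_WEIGHT_SUFFIX = "_weight"


def merge_with_neighbors(node_id, node_features, graph, max_nbrs):
    """Builds one merged training example for `node_id`.

    Hypothetical helper: `node_features` maps node id -> 'words' vector and
    `graph` maps node id -> list of neighbor ids.
    """
    merged = {"words": node_features[node_id]}
    neighbors = graph.get(node_id, [])[:max_nbrs]  # keep at most max_nbrs
    for i, nbr_id in enumerate(neighbors):
        words_key = "{}{}_{}".format(NBR_FEATURE_PREFIX, i, "words")
        weight_key = "{}{}{}".format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX)
        merged[words_key] = node_features[nbr_id]
        merged[weight_key] = 1.0  # assumed edge weight for an unweighted graph
    return merged


node_features = {"p1": [1, 0, 1], "p2": [0, 1, 1], "p3": [1, 1, 0]}
graph = {"p1": ["p2", "p3"], "p2": ["p1"], "p3": ["p1"]}
merged = merge_with_neighbors("p1", node_features, graph, max_nbrs=1)
print(sorted(merged))
```

With max_nbrs=1 only the first neighbor is kept, so the merged example has one pair of neighbor keys in addition to 'words'.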
Global variables
The file paths to the train and test data are based on the command-line flag values used to invoke the 'preprocess_cora_dataset.py' script above.
### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'
### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'
Hyperparameters
We will use an instance of HParams to hold the various hyperparameters and constants used for training and evaluation. We briefly describe each of them below:
num_classes: There are a total of 7 different classes.
max_seq_length: This is the size of the vocabulary, and all instances in the input have a dense multi-hot, bag-of-words representation. In other words, a value of 1 for a word indicates that the word is present in the input, and a value of 0 indicates that it is not.
distance_type: This is the distance metric used to regularize the sample with its neighbors.
graph_regularization_multiplier: This controls the relative weight of the graph regularization term in the overall loss function.
num_neighbors: The number of neighbors used for graph regularization. This value must be less than or equal to the max_nbrs command-line argument used above when running preprocess_cora_dataset.py.
num_fc_units: A list giving the number of units in each fully connected layer of our neural network.
train_epochs: The number of training epochs.
batch_size: The batch size used for training and evaluation.
dropout_rate: Controls the dropout rate following each fully connected layer.
eval_steps: The number of batches to process before deeming evaluation complete. If set to None, all instances in the test set are evaluated.
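The constraint between num_neighbors and max_nbrs can be checked up front. The helper below is a hypothetical convenience, not part of NSL; the value 5 is taken from the --max_nbrs flag in the preprocessing command above.

```python
def check_num_neighbors(num_neighbors, max_nbrs):
    """Raises ValueError if more neighbors are requested than were generated."""
    if not 0 <= num_neighbors <= max_nbrs:
        raise ValueError(
            "num_neighbors={} must be between 0 and max_nbrs={}".format(
                num_neighbors, max_nbrs))
    return num_neighbors


MAX_NBRS = 5  # value of --max_nbrs used during preprocessing
num_neighbors = check_num_neighbors(1, MAX_NBRS)
print(num_neighbors)
```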
class HParams(object):
  """Hyperparameters used for training."""

  def __init__(self):
    ### dataset parameters
    self.num_classes = 7
    self.max_seq_length = 1433
    ### neural graph learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    self.num_neighbors = 1
    ### model architecture
    self.num_fc_units = [50, 50]
    ### training parameters
    self.train_epochs = 100
    self.batch_size = 128
    self.dropout_rate = 0.5
    ### eval parameters
    self.eval_steps = None  # All instances in the test set are evaluated.


HPARAMS = HParams()
Load train and test data
As described earlier in this notebook, the input training and test data have been created by 'preprocess_cora_dataset.py'. We will load them into two tf.data.Dataset objects: one for train and one for test.
In the input layer of our model, we will extract not just the 'words' and the 'label' features from each sample, but also the corresponding neighbor features based on the hparams.num_neighbors value. Instances with fewer neighbors than hparams.num_neighbors will be assigned dummy values for those non-existent neighbor features.
def make_dataset(file_path, training=False):
  """Creates a `tf.data.TFRecordDataset`.

  Args:
    file_path: Name of the file in the `.tfrecord` format containing
      `tf.train.Example` objects.
    training: Boolean indicating if we are in training mode.

  Returns:
    An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
    objects.
  """

  def parse_example(example_proto):
    """Extracts relevant fields from the `example_proto`.

    Args:
      example_proto: An instance of `tf.train.Example`.

    Returns:
      A pair whose first value is a dictionary containing relevant features
      and whose second value contains the ground truth label.
    """
    # The 'words' feature is a multi-hot, bag-of-words representation of the
    # original raw text. A default value is required for examples that don't
    # have the feature.
    feature_spec = {
        'words':
            tf.io.FixedLenFeature([HPARAMS.max_seq_length],
                                  tf.int64,
                                  default_value=tf.constant(
                                      0,
                                      dtype=tf.int64,
                                      shape=[HPARAMS.max_seq_length])),
        'label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }
    # We also extract corresponding neighbor features in a similar manner to
    # the features above during training.
    if training:
      for i in range(HPARAMS.num_neighbors):
        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
                                         NBR_WEIGHT_SUFFIX)
        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
            [HPARAMS.max_seq_length],
            tf.int64,
            default_value=tf.constant(
                0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))
        # We assign a default value of 0.0 for the neighbor weight so that
        # graph regularization is done on samples based on their exact number
        # of neighbors. In other words, non-existent neighbors are discounted.
        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
            [1], tf.float32, default_value=tf.constant([0.0]))

    features = tf.io.parse_single_example(example_proto, feature_spec)
    label = features.pop('label')
    return features, label

  dataset = tf.data.TFRecordDataset([file_path])
  if training:
    dataset = dataset.shuffle(10000)
  dataset = dataset.map(parse_example)
  dataset = dataset.batch(HPARAMS.batch_size)
  return dataset


train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)
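The effect of the default values in the feature spec can be mimicked in plain Python. This is only an analogy for what the parsing defaults achieve, using a hypothetical helper and toy 3-element vectors instead of the 1433-element ones.

```python
def fill_missing_neighbors(example, num_neighbors, vector_length):
    """Adds zero vectors and 0.0 weights for neighbors absent from `example`."""
    filled = dict(example)
    for i in range(num_neighbors):
        filled.setdefault("NL_nbr_{}_words".format(i), [0] * vector_length)
        filled.setdefault("NL_nbr_{}_weight".format(i), 0.0)
    return filled


# This sample has one real neighbor, but two neighbors are requested.
example = {
    "words": [1, 0, 1],
    "NL_nbr_0_words": [0, 1, 1],
    "NL_nbr_0_weight": 1.0,
}
filled = fill_missing_neighbors(example, num_neighbors=2, vector_length=3)
print(filled["NL_nbr_1_words"], filled["NL_nbr_1_weight"])
```

Because the dummy neighbor gets a weight of 0.0, it adds nothing to a weighted graph regularization term, while the real neighbor keeps its original weight.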
Let's peek into the train dataset to look at its contents.
for feature_batch, label_batch in train_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
  print('Batch of neighbor weights:',
        tf.reshape(feature_batch[nbr_weight_key], [-1]))
  print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 1 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor weights: tf.Tensor( [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32) Batch of labels: tf.Tensor( [2 2 6 2 0 6 1 3 5 0 1 2 3 6 1 1 0 3 5 2 3 1 4 1 6 1 3 2 2 2 0 3 2 1 3 3 2 3 3 2 3 2 2 0 2 2 6 0 2 1 1 0 5 2 1 4 2 1 2 4 0 2 5 4 3 6 3 2 1 6 2 4 2 2 6 4 6 4 3 5 2 2 2 4 2 2 2 1 2 2 2 4 2 3 6 2 0 6 6 0 2 6 2 1 2 0 1 1 3 2 0 2 0 2 1 1 3 5 2 1 2 5 1 6 2 4 6 4], shape=(128,), dtype=int64)
Let's peek into the test dataset to look at its contents.
for feature_batch, label_batch in test_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  print('Batch of labels:', label_batch)
Feature list: ['words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of labels: tf.Tensor( [5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)
Model definition
To demonstrate the use of graph regularization, we first build a base model for this problem. We will use a simple feed-forward neural network with 2 hidden layers and dropout in between. We illustrate the creation of the base model using all model types supported by the tf.Keras framework: sequential, functional, and subclass.
Sequential base model
def make_mlp_sequential_model(hparams):
  """Creates a sequential multi-layer perceptron model."""
  model = tf.keras.Sequential()
  model.add(
      tf.keras.layers.InputLayer(
          input_shape=(hparams.max_seq_length,), name='words'))
  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  model.add(
      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
  for num_units in hparams.num_fc_units:
    model.add(tf.keras.layers.Dense(num_units, activation='relu'))
    # For sequential models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
  model.add(tf.keras.layers.Dense(hparams.num_classes))
  return model
Functional base model
def make_mlp_functional_model(hparams):
  """Creates a functional API-based multi-layer perceptron model."""
  inputs = tf.keras.Input(
      shape=(hparams.max_seq_length,), dtype='int64', name='words')
  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  cur_layer = tf.keras.layers.Lambda(
      lambda x: tf.keras.backend.cast(x, tf.float32))(
          inputs)
  for num_units in hparams.num_fc_units:
    cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
    # For functional models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)
  outputs = tf.keras.layers.Dense(hparams.num_classes)(cur_layer)
  model = tf.keras.Model(inputs, outputs=outputs)
  return model
Subclass base model
def make_mlp_subclass_model(hparams):
  """Creates a multi-layer perceptron subclass model in Keras."""

  class MLP(tf.keras.Model):
    """Subclass model defining a multi-layer perceptron."""

    def __init__(self):
      super(MLP, self).__init__()
      # Input is already one-hot encoded in the integer format. We create a
      # layer to cast it to floating point format here.
      self.cast_to_float_layer = tf.keras.layers.Lambda(
          lambda x: tf.keras.backend.cast(x, tf.float32))
      self.dense_layers = [
          tf.keras.layers.Dense(num_units, activation='relu')
          for num_units in hparams.num_fc_units
      ]
      self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
      self.output_layer = tf.keras.layers.Dense(hparams.num_classes)

    def call(self, inputs, training=False):
      cur_layer = self.cast_to_float_layer(inputs['words'])
      for dense_layer in self.dense_layers:
        cur_layer = dense_layer(cur_layer)
        cur_layer = self.dropout_layer(cur_layer, training=training)
      outputs = self.output_layer(cur_layer)
      return outputs

  return MLP()
Create base model(s)
# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
words (InputLayer) [(None, 1433)] 0
lambda (Lambda) (None, 1433) 0
dense (Dense) (None, 50) 71700
dropout (Dropout) (None, 50) 0
dense_1 (Dense) (None, 50) 2550
dropout_1 (Dropout) (None, 50) 0
dense_2 (Dense) (None, 7) 357
=================================================================
Total params: 74,607
Trainable params: 74,607
Non-trainable params: 0
_________________________________________________________________
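The parameter counts in the summary can be reproduced by hand: a dense layer with n inputs and m units has n * m weights plus m biases. The short check below uses only the layer sizes shown above; the helper name is illustrative.

```python
def dense_params(num_inputs, num_units):
    """Weights plus biases of a fully connected layer."""
    return num_inputs * num_units + num_units


# Input size, two hidden layers, and the output layer from the summary.
layer_sizes = [1433, 50, 50, 7]
counts = [
    dense_params(n_in, n_out)
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:])
]
total_params = sum(counts)
print(counts, total_params)
```

The Lambda and Dropout layers have no trainable weights, so the three dense layers account for all of the parameters.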
Train base MLP model
# Compile and train the base MLP model
base_model.compile(
optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100 /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/functional.py:559: UserWarning: Input dict contained keys ['NL_nbr_0_weight', 'NL_nbr_0_words'] which did not match any model input. They will be ignored by the model. inputs = self._flatten_to_reference_inputs(inputs) 17/17 [==============================] - 1s 18ms/step - loss: 1.9521 - accuracy: 0.1838 Epoch 2/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8590 - accuracy: 0.3044 Epoch 3/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7770 - accuracy: 0.3601 Epoch 4/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6655 - accuracy: 0.3898 Epoch 5/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5386 - accuracy: 0.4543 Epoch 6/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3856 - accuracy: 0.5077 Epoch 7/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2736 - accuracy: 0.5531 Epoch 8/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1636 - accuracy: 0.5889 Epoch 9/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0654 - accuracy: 0.6385 Epoch 10/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9703 - accuracy: 0.6761 Epoch 11/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8689 - accuracy: 0.7104 Epoch 12/100 17/17 [==============================] - 0s 3ms/step - loss: 0.7704 - accuracy: 0.7494 Epoch 13/100 17/17 [==============================] - 0s 3ms/step - loss: 0.7157 - accuracy: 0.7810 Epoch 14/100 17/17 [==============================] - 0s 3ms/step - loss: 0.6296 - accuracy: 0.8186 Epoch 15/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5932 - accuracy: 0.8167 Epoch 16/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5526 - accuracy: 0.8464 Epoch 17/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5112 - accuracy: 0.8445 Epoch 18/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.4624 - accuracy: 0.8613 Epoch 19/100 17/17 [==============================] - 0s 3ms/step - loss: 0.4163 - accuracy: 0.8696 Epoch 20/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3808 - accuracy: 0.8849 Epoch 21/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3564 - accuracy: 0.8933 Epoch 22/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3453 - accuracy: 0.9002 Epoch 23/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3226 - accuracy: 0.9114 Epoch 24/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3058 - accuracy: 0.9151 Epoch 25/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2798 - accuracy: 0.9146 Epoch 26/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2638 - accuracy: 0.9248 Epoch 27/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2538 - accuracy: 0.9290 Epoch 28/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2356 - accuracy: 0.9411 Epoch 29/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2080 - accuracy: 0.9425 Epoch 30/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2172 - accuracy: 0.9364 Epoch 31/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2259 - accuracy: 0.9225 Epoch 32/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1944 - accuracy: 0.9480 Epoch 33/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1892 - accuracy: 0.9434 Epoch 34/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1718 - accuracy: 0.9592 Epoch 35/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1826 - accuracy: 0.9508 Epoch 36/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1585 - accuracy: 0.9559 Epoch 37/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1605 - accuracy: 0.9545 Epoch 38/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.1529 - accuracy: 0.9550 Epoch 39/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1411 - accuracy: 0.9615 Epoch 40/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1366 - accuracy: 0.9624 Epoch 41/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1431 - accuracy: 0.9578 Epoch 42/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1241 - accuracy: 0.9619 Epoch 43/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1310 - accuracy: 0.9661 Epoch 44/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1284 - accuracy: 0.9652 Epoch 45/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1215 - accuracy: 0.9633 Epoch 46/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1130 - accuracy: 0.9722 Epoch 47/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1074 - accuracy: 0.9722 Epoch 48/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1143 - accuracy: 0.9694 Epoch 49/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1015 - accuracy: 0.9740 Epoch 50/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1077 - accuracy: 0.9698 Epoch 51/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1035 - accuracy: 0.9684 Epoch 52/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1076 - accuracy: 0.9694 Epoch 53/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1000 - accuracy: 0.9689 Epoch 54/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0967 - accuracy: 0.9749 Epoch 55/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0994 - accuracy: 0.9703 Epoch 56/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0943 - accuracy: 0.9740 Epoch 57/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0923 - accuracy: 0.9735 Epoch 58/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.0848 - accuracy: 0.9800 Epoch 59/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0836 - accuracy: 0.9782 Epoch 60/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0913 - accuracy: 0.9735 Epoch 61/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0823 - accuracy: 0.9773 Epoch 62/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0753 - accuracy: 0.9810 Epoch 63/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0746 - accuracy: 0.9777 Epoch 64/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0861 - accuracy: 0.9731 Epoch 65/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0765 - accuracy: 0.9787 Epoch 66/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0750 - accuracy: 0.9791 Epoch 67/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0725 - accuracy: 0.9814 Epoch 68/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0762 - accuracy: 0.9791 Epoch 69/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0645 - accuracy: 0.9842 Epoch 70/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0606 - accuracy: 0.9861 Epoch 71/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0775 - accuracy: 0.9805 Epoch 72/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0655 - accuracy: 0.9800 Epoch 73/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0629 - accuracy: 0.9833 Epoch 74/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0625 - accuracy: 0.9824 Epoch 75/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0607 - accuracy: 0.9838 Epoch 76/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0578 - accuracy: 0.9824 Epoch 77/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0568 - accuracy: 0.9842 Epoch 78/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.0595 - accuracy: 0.9833 Epoch 79/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0615 - accuracy: 0.9842 Epoch 80/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0555 - accuracy: 0.9852 Epoch 81/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0517 - accuracy: 0.9870 Epoch 82/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0541 - accuracy: 0.9856 Epoch 83/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0533 - accuracy: 0.9884 Epoch 84/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0509 - accuracy: 0.9838 Epoch 85/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0600 - accuracy: 0.9828 Epoch 86/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0617 - accuracy: 0.9800 Epoch 87/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0599 - accuracy: 0.9800 Epoch 88/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0502 - accuracy: 0.9870 Epoch 89/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0416 - accuracy: 0.9907 Epoch 90/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0542 - accuracy: 0.9842 Epoch 91/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0490 - accuracy: 0.9847 Epoch 92/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0374 - accuracy: 0.9916 Epoch 93/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0467 - accuracy: 0.9893 Epoch 94/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0426 - accuracy: 0.9879 Epoch 95/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0543 - accuracy: 0.9861 Epoch 96/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0420 - accuracy: 0.9870 Epoch 97/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0461 - accuracy: 0.9861 Epoch 98/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.0425 - accuracy: 0.9898 Epoch 99/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0406 - accuracy: 0.9907 Epoch 100/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0486 - accuracy: 0.9847 <keras.callbacks.History at 0x7f6f9d5eacd0>
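The loss used above, sparse categorical cross-entropy computed from logits, can be sketched in plain Python as follows. This is an illustrative re-derivation for a single example, not the Keras implementation, and the function name is made up.

```python
import math


def sparse_ce_from_logits(logits, label):
    """Cross-entropy of one example given raw logits and an integer label."""
    # Subtract the max logit before exponentiating for numerical stability.
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(
        sum(math.exp(z - max_logit) for z in logits))
    return log_sum_exp - logits[label]


# With equal logits over 7 classes the model predicts 1/7 for every class.
uniform_loss = sparse_ce_from_logits([0.0] * 7, label=3)
print(uniform_loss)
```

For equal logits the loss is ln 7, roughly 1.946, which is close to the first-epoch training loss of 1.9521 reported above, as one would expect from a freshly initialized 7-class model.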
Evaluate base MLP model
# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
  """Prints evaluation metrics.

  Args:
    model_desc: A description of the model.
    eval_metrics: A dictionary mapping metric names to corresponding values. It
      must contain the loss and accuracy metrics.
  """
  print('\n')
  print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
  print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
  if 'graph_loss' in eval_metrics:
    print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])


eval_results = dict(
    zip(base_model.metrics_names,
        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 1.4192 - accuracy: 0.7939
Eval accuracy for Base MLP model : 0.7938517332077026
Eval loss for Base MLP model : 1.4192423820495605
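The step counts shown in the progress bars (17 during training, 5 during evaluation) follow from the dataset sizes and the batch size of 128. The check below assumes that every node not written to the training file ends up in the test file; the preprocessing log does not state the test-set size explicitly.

```python
import math

TOTAL_NODES = 2708     # "Total graph nodes" in the preprocessing log
TRAIN_EXAMPLES = 2155  # merged tf.train.Examples written for training
BATCH_SIZE = 128

train_steps = math.ceil(TRAIN_EXAMPLES / BATCH_SIZE)
# Assumption: the remaining nodes form the test set.
test_examples = TOTAL_NODES - TRAIN_EXAMPLES
test_steps = math.ceil(test_examples / BATCH_SIZE)
print(train_steps, test_examples, test_steps)
```

Both values agree with the "17/17" and "5/5" counters in the logs, since the last, partially filled batch still counts as a step.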
Train MLP model with graph regularization
Incorporating graph regularization into the loss term of an existing tf.Keras.Model requires just a few lines of code. The base model is wrapped to create a new tf.Keras subclass model, whose loss includes graph regularization.
To assess the incremental benefit of graph regularization, we will create a new base model instance. This is because base_model has already been trained for several iterations, and reusing this trained model to create a graph-regularized model would not be a fair comparison against base_model.
# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
max_neighbors=HPARAMS.num_neighbors,
multiplier=HPARAMS.graph_regularization_multiplier,
distance_type=HPARAMS.distance_type,
sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
graph_reg_config)
graph_reg_model.compile(
optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape_1:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape:0", shape=(None, 7), dtype=float32), dense_shape=Tensor("gradient_tape/GraphRegularization/graph_loss/Cast:0", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
"shape. This may consume a large amount of memory." % value)
17/17 [==============================] - 2s 4ms/step - loss: 1.9798 - accuracy: 0.1601 - scaled_graph_loss: 0.0373
Epoch 2/100
17/17 [==============================] - 0s 3ms/step - loss: 1.9024 - accuracy: 0.2979 - scaled_graph_loss: 0.0254
Epoch 3/100
17/17 [==============================] - 0s 3ms/step - loss: 1.8623 - accuracy: 0.3160 - scaled_graph_loss: 0.0317
Epoch 4/100
17/17 [==============================] - 0s 3ms/step - loss: 1.8042 - accuracy: 0.3443 - scaled_graph_loss: 0.0498
Epoch 5/100
17/17 [==============================] - 0s 3ms/step - loss: 1.7552 - accuracy: 0.3582 - scaled_graph_loss: 0.0696
Epoch 6/100
17/17 [==============================] - 0s 3ms/step - loss: 1.7012 - accuracy: 0.4084 - scaled_graph_loss: 0.0866
Epoch 7/100
17/17 [==============================] - 0s 3ms/step - loss: 1.6578 - accuracy: 0.4515 - scaled_graph_loss: 0.1114
Epoch 8/100
17/17 [==============================] - 0s 3ms/step - loss: 1.6058 - accuracy: 0.5039 - scaled_graph_loss: 0.1300
Epoch 9/100
17/17 [==============================] - 0s 3ms/step - loss: 1.5498 - accuracy: 0.5434 - scaled_graph_loss: 0.1508
Epoch 10/100
17/17 [==============================] - 0s 3ms/step - loss: 1.5098 - accuracy: 0.6019 - scaled_graph_loss: 0.1651
Epoch 11/100
17/17 [==============================] - 0s 3ms/step - loss: 1.4746 - accuracy: 0.6302 - scaled_graph_loss: 0.1844
Epoch 12/100
17/17 [==============================] - 0s 3ms/step - loss: 1.4315 - accuracy: 0.6520 - scaled_graph_loss: 0.1917
Epoch 13/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3932 - accuracy: 0.6770 - scaled_graph_loss: 0.2024
Epoch 14/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3645 - accuracy: 0.7183 - scaled_graph_loss: 0.2145
Epoch 15/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3265 - accuracy: 0.7369 - scaled_graph_loss: 0.2324
Epoch 16/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3045 - accuracy: 0.7555 - scaled_graph_loss: 0.2358
Epoch 17/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2836 - accuracy: 0.7652 - scaled_graph_loss: 0.2404
Epoch 18/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2456 - accuracy: 0.7898 - scaled_graph_loss: 0.2469
Epoch 19/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2348 - accuracy: 0.8074 - scaled_graph_loss: 0.2615
Epoch 20/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2000 - accuracy: 0.8074 - scaled_graph_loss: 0.2542
Epoch 21/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1994 - accuracy: 0.8260 - scaled_graph_loss: 0.2729
Epoch 22/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1825 - accuracy: 0.8269 - scaled_graph_loss: 0.2676
Epoch 23/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1598 - accuracy: 0.8455 - scaled_graph_loss: 0.2742
Epoch 24/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1543 - accuracy: 0.8534 - scaled_graph_loss: 0.2797
Epoch 25/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1456 - accuracy: 0.8552 - scaled_graph_loss: 0.2714
Epoch 26/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1154 - accuracy: 0.8566 - scaled_graph_loss: 0.2796
Epoch 27/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1150 - accuracy: 0.8687 - scaled_graph_loss: 0.2850
Epoch 28/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1154 - accuracy: 0.8626 - scaled_graph_loss: 0.2772
Epoch 29/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0806 - accuracy: 0.8733 - scaled_graph_loss: 0.2756
Epoch 30/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0828 - accuracy: 0.8626 - scaled_graph_loss: 0.2907
Epoch 31/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0724 - accuracy: 0.8886 - scaled_graph_loss: 0.2834
Epoch 32/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0589 - accuracy: 0.8826 - scaled_graph_loss: 0.2881
Epoch 33/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0490 - accuracy: 0.8872 - scaled_graph_loss: 0.2972
Epoch 34/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0550 - accuracy: 0.8923 - scaled_graph_loss: 0.2935
Epoch 35/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0397 - accuracy: 0.8840 - scaled_graph_loss: 0.2795
Epoch 36/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0360 - accuracy: 0.8891 - scaled_graph_loss: 0.2966
Epoch 37/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0235 - accuracy: 0.8961 - scaled_graph_loss: 0.2890
Epoch 38/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0219 - accuracy: 0.8984 - scaled_graph_loss: 0.2965
Epoch 39/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0168 - accuracy: 0.9044 - scaled_graph_loss: 0.3023
Epoch 40/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0148 - accuracy: 0.9035 - scaled_graph_loss: 0.2984
Epoch 41/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9956 - accuracy: 0.9118 - scaled_graph_loss: 0.2888
Epoch 42/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0019 - accuracy: 0.9021 - scaled_graph_loss: 0.2877
Epoch 43/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9956 - accuracy: 0.9049 - scaled_graph_loss: 0.2912
Epoch 44/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9986 - accuracy: 0.9026 - scaled_graph_loss: 0.3040
Epoch 45/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9939 - accuracy: 0.9067 - scaled_graph_loss: 0.3016
Epoch 46/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9828 - accuracy: 0.9058 - scaled_graph_loss: 0.2877
Epoch 47/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9629 - accuracy: 0.9137 - scaled_graph_loss: 0.2844
Epoch 48/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9645 - accuracy: 0.9146 - scaled_graph_loss: 0.2933
Epoch 49/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9752 - accuracy: 0.9165 - scaled_graph_loss: 0.3013
Epoch 50/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9552 - accuracy: 0.9179 - scaled_graph_loss: 0.2865
Epoch 51/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9539 - accuracy: 0.9193 - scaled_graph_loss: 0.3044
Epoch 52/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9443 - accuracy: 0.9183 - scaled_graph_loss: 0.3010
Epoch 53/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9559 - accuracy: 0.9244 - scaled_graph_loss: 0.2987
Epoch 54/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9497 - accuracy: 0.9225 - scaled_graph_loss: 0.2979
Epoch 55/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9674 - accuracy: 0.9183 - scaled_graph_loss: 0.3034
Epoch 56/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9537 - accuracy: 0.9174 - scaled_graph_loss: 0.2834
Epoch 57/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9341 - accuracy: 0.9188 - scaled_graph_loss: 0.2939
Epoch 58/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9392 - accuracy: 0.9225 - scaled_graph_loss: 0.2998
Epoch 59/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9240 - accuracy: 0.9313 - scaled_graph_loss: 0.3022
Epoch 60/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9368 - accuracy: 0.9267 - scaled_graph_loss: 0.2979
Epoch 61/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9306 - accuracy: 0.9234 - scaled_graph_loss: 0.2952
Epoch 62/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9197 - accuracy: 0.9230 - scaled_graph_loss: 0.2916
Epoch 63/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9360 - accuracy: 0.9206 - scaled_graph_loss: 0.2947
Epoch 64/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9181 - accuracy: 0.9299 - scaled_graph_loss: 0.2996
Epoch 65/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9105 - accuracy: 0.9341 - scaled_graph_loss: 0.2981
Epoch 66/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9014 - accuracy: 0.9323 - scaled_graph_loss: 0.2897
Epoch 67/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9059 - accuracy: 0.9364 - scaled_graph_loss: 0.3083
Epoch 68/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9053 - accuracy: 0.9309 - scaled_graph_loss: 0.2976
Epoch 69/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9099 - accuracy: 0.9258 - scaled_graph_loss: 0.3069
Epoch 70/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9025 - accuracy: 0.9355 - scaled_graph_loss: 0.2890
Epoch 71/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8849 - accuracy: 0.9281 - scaled_graph_loss: 0.2933
Epoch 72/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8959 - accuracy: 0.9323 - scaled_graph_loss: 0.2918
Epoch 73/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9074 - accuracy: 0.9248 - scaled_graph_loss: 0.3065
Epoch 74/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8845 - accuracy: 0.9369 - scaled_graph_loss: 0.2874
Epoch 75/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8873 - accuracy: 0.9401 - scaled_graph_loss: 0.2996
Epoch 76/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8942 - accuracy: 0.9327 - scaled_graph_loss: 0.3086
Epoch 77/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9052 - accuracy: 0.9253 - scaled_graph_loss: 0.2986
Epoch 78/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8811 - accuracy: 0.9336 - scaled_graph_loss: 0.2948
Epoch 79/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8896 - accuracy: 0.9276 - scaled_graph_loss: 0.2919
Epoch 80/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8853 - accuracy: 0.9313 - scaled_graph_loss: 0.2944
Epoch 81/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8875 - accuracy: 0.9323 - scaled_graph_loss: 0.2925
Epoch 82/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8639 - accuracy: 0.9323 - scaled_graph_loss: 0.2967
Epoch 83/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8820 - accuracy: 0.9332 - scaled_graph_loss: 0.3047
Epoch 84/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8752 - accuracy: 0.9346 - scaled_graph_loss: 0.2942
Epoch 85/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8651 - accuracy: 0.9374 - scaled_graph_loss: 0.3066
Epoch 86/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8765 - accuracy: 0.9332 - scaled_graph_loss: 0.2881
Epoch 87/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8691 - accuracy: 0.9420 - scaled_graph_loss: 0.3030
Epoch 88/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8631 - accuracy: 0.9374 - scaled_graph_loss: 0.2916
Epoch 89/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8651 - accuracy: 0.9392 - scaled_graph_loss: 0.3032
Epoch 90/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8632 - accuracy: 0.9420 - scaled_graph_loss: 0.3019
Epoch 91/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8600 - accuracy: 0.9425 - scaled_graph_loss: 0.2965
Epoch 92/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8569 - accuracy: 0.9346 - scaled_graph_loss: 0.2977
Epoch 93/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8704 - accuracy: 0.9374 - scaled_graph_loss: 0.3083
Epoch 94/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8562 - accuracy: 0.9406 - scaled_graph_loss: 0.2883
Epoch 95/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8545 - accuracy: 0.9415 - scaled_graph_loss: 0.3030
Epoch 96/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8592 - accuracy: 0.9332 - scaled_graph_loss: 0.2927
Epoch 97/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8503 - accuracy: 0.9397 - scaled_graph_loss: 0.2927
Epoch 98/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8434 - accuracy: 0.9462 - scaled_graph_loss: 0.2937
Epoch 99/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8578 - accuracy: 0.9374 - scaled_graph_loss: 0.3064
Epoch 100/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8504 - accuracy: 0.9411 - scaled_graph_loss: 0.3043
<keras.callbacks.History at 0x7f70041be650>
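The training log reports `loss` alongside `scaled_graph_loss`. As a rough sketch, assuming the reported `loss` is the sum of the supervised loss and the scaled graph regularization term (which matches the overview's description of the graph loss as a regularization term in the training objective), the supervised part can be estimated by subtraction. The helper name `supervised_loss` is illustrative and not part of NSL; the numbers are copied from the first and last epochs of the log above.

```python
def supervised_loss(total_loss, scaled_graph_loss):
    """Estimate the supervised loss, assuming total = supervised + scaled graph loss."""
    return total_loss - scaled_graph_loss


# (loss, scaled_graph_loss) pairs copied from epochs 1 and 100 of the log above.
epochs = {1: (1.9798, 0.0373), 100: (0.8504, 0.3043)}

# Estimated supervised loss per epoch, under the additive assumption stated above.
estimates = {
    epoch: round(supervised_loss(total, graph), 4)
    for epoch, (total, graph) in epochs.items()
}

for epoch, value in estimates.items():
    print(f"epoch {epoch}: estimated supervised loss = {value:.4f}")
```

Under this assumption, the supervised part falls from about 1.94 to about 0.55 over training while the graph term grows, so a large share of the final training loss would come from the regularizer rather than from classification error.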
Evaluate the MLP model with graph regularization
eval_results = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 0.8884 - accuracy: 0.7957
Eval accuracy for MLP + graph regularization : 0.7956600189208984
Eval loss for MLP + graph regularization : 0.8883611559867859
The accuracy of the graph-regularized model is about 2-3% higher than that of the base model (base_model).
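To check the gain on your own run, the two evaluation dictionaries can be compared directly. The sketch below is illustrative only: `accuracy_gain` is a hypothetical helper, and the baseline accuracy is a placeholder rather than a value from this tutorial's output. Only the graph-regularized accuracy is taken from the evaluation above.

```python
def accuracy_gain(base_results, graph_reg_results):
    """Return the accuracy difference in percentage points."""
    return 100.0 * (graph_reg_results['accuracy'] - base_results['accuracy'])


# Placeholder baseline (an assumption): replace with the result of evaluating base_model.
base_eval_results = {'accuracy': 0.77}
# Accuracy reported above for the graph-regularized model.
graph_reg_eval_results = {'accuracy': 0.7956600189208984}

gain = accuracy_gain(base_eval_results, graph_reg_eval_results)
print(f"Accuracy gain: {gain:.2f} percentage points")
```

Because results vary from run to run, comparing several runs gives a more reliable picture than a single pair of numbers.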
Conclusion
We have demonstrated the use of graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. Our advanced tutorial involves synthesizing graphs from sample embeddings before training a neural network with graph regularization. That approach is useful when the input does not contain an explicit graph.
We encourage users to experiment further by varying the amount of supervision and by trying different neural architectures for graph regularization.