Overview
This tutorial focuses on preparing tf.data.Datasets by reading data from MongoDB collections and using them to train a tf.keras model.
Setup packages
This tutorial uses pymongo as a helper package to create a new MongoDB database and collection to store the data.
Install the required tensorflow-io and mongodb (helper) packages
pip install -q tensorflow-io
pip install -q pymongo
Import packages
import os
import time
from pprint import pprint
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing
import tensorflow_io as tfio
from pymongo import MongoClient
Validate tf and tfio imports
print("tensorflow-io version: {}".format(tfio.__version__))
print("tensorflow version: {}".format(tf.__version__))
tensorflow-io version: 0.20.0
tensorflow version: 2.6.0
Download and setup the MongoDB instance
For demo purposes, the open-source version of mongodb is used.
sudo apt install -y mongodb >log
service mongodb start
* Starting database mongodb
...done.
WARNING: apt does not have a stable CLI interface. Use with caution in scripts.
debconf: unable to initialize frontend: Dialog
debconf: (No usable dialog-like program is installed, so the dialog based frontend cannot be used. at /usr/share/perl5/Debconf/FrontEnd/Dialog.pm line 76, <> line 8.)
debconf: falling back to frontend: Readline
debconf: unable to initialize frontend: Readline
debconf: (This frontend requires a controlling tty.)
debconf: falling back to frontend: Teletype
dpkg-preconfigure: unable to re-open stdin:
# Sleep for a few seconds to let the instance start.
time.sleep(5)
Once the instance has started, grep for mongo in the process list to confirm that it is available.
ps -ef | grep mongo
mongodb 580 1 13 17:38 ? 00:00:00 /usr/bin/mongod --config /etc/mongodb.conf
root 612 610 0 17:38 ? 00:00:00 grep mongo
Query the base endpoint to retrieve information about the cluster.
client = MongoClient()
client.list_database_names() # ['admin', 'local']
['admin', 'local']
Explore the dataset
For the purpose of this tutorial, let's download the PetFinder dataset and feed the data into MongoDB manually. The goal of this classification problem is to predict whether the pet will be adopted or not.
dataset_url = 'http://storage.googleapis.com/download.tensorflow.org/data/petfinder-mini.zip'
csv_file = 'datasets/petfinder-mini/petfinder-mini.csv'
tf.keras.utils.get_file('petfinder_mini.zip', dataset_url,
                        extract=True, cache_dir='.')
pf_df = pd.read_csv(csv_file)
Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/petfinder-mini.zip
1671168/1668792 [==============================] - 0s 0us/step
1679360/1668792 [==============================] - 0s 0us/step
pf_df.head()
For the purpose of the tutorial, modifications are made to the label column. 0 will indicate the pet was not adopted, and 1 will indicate that it was.
# In the original dataset "4" indicates the pet was not adopted.
pf_df['target'] = np.where(pf_df['AdoptionSpeed']==4, 0, 1)
# Drop un-used columns.
pf_df = pf_df.drop(columns=['AdoptionSpeed', 'Description'])
# Number of datapoints and columns
len(pf_df), len(pf_df.columns)
(11537, 14)
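To make the relabeling concrete, here is a minimal sketch on a few hypothetical AdoptionSpeed values. It uses a plain list comprehension in place of np.where so that it runs without any dependencies. The sample values are made up for illustration and are not taken from the dataset.

```python
# Hypothetical AdoptionSpeed values; in the original dataset, 4 means "not adopted".
adoption_speed = [0, 1, 2, 3, 4, 4, 2]

# Same rule as np.where(pf_df['AdoptionSpeed'] == 4, 0, 1):
# 4 -> 0 (not adopted), anything else -> 1 (adopted).
target = [0 if speed == 4 else 1 for speed in adoption_speed]

print(target)
```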
Split the dataset
train_df, test_df = train_test_split(pf_df, test_size=0.3, shuffle=True)
print("Number of training samples: ",len(train_df))
print("Number of testing sample: ",len(test_df))
Number of training samples: 8075
Number of testing sample: 3462
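The sizes printed above follow from the test_size=0.3 argument. As a rough check, assuming scikit-learn rounds the test fraction up and assigns the remaining rows to training (which is how it treats a float test_size when train_size is left unset), the counts can be reproduced with plain arithmetic:

```python
import math

n_samples = 11537  # rows in pf_df, as reported earlier
test_size = 0.3

# Assumption: the test share is rounded up, the rest goes to training.
n_test = math.ceil(test_size * n_samples)
n_train = n_samples - n_test

print(n_train, n_test)
```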
Store the train and test data in mongo collections
URI = "mongodb://localhost:27017"
DATABASE = "tfiodb"
TRAIN_COLLECTION = "train"
TEST_COLLECTION = "test"
db = client[DATABASE]
if "train" not in db.list_collection_names():
    db.create_collection(TRAIN_COLLECTION)
if "test" not in db.list_collection_names():
    db.create_collection(TEST_COLLECTION)
def store_records(collection, records):
    writer = tfio.experimental.mongodb.MongoDBWriter(
        uri=URI, database=DATABASE, collection=collection
    )
    for record in records:
        writer.write(record)
store_records(collection="train", records=train_df.to_dict("records"))
time.sleep(2)
store_records(collection="test", records=test_df.to_dict("records"))
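For context, DataFrame.to_dict("records") yields one dictionary per row, and each dictionary becomes one document in the collection. The sketch below mimics that shape with two hand-written rows (the values are hypothetical) and serializes them with the standard json module to show what a JSON-encoded document looks like. Documents actually stored in MongoDB may also carry extra fields, such as the _id that MongoDB adds to every document.

```python
import json

# Two hypothetical rows, shaped like the output of train_df.to_dict("records").
records = [
    {"Type": "Cat", "Age": 3, "PhotoAmt": 1, "Fee": 100, "target": 1},
    {"Type": "Dog", "Age": 1, "PhotoAmt": 2, "Fee": 0, "target": 0},
]

# Each record is a flat dict, so it maps directly onto one JSON document.
documents = [json.dumps(record) for record in records]

for doc in documents:
    print(doc)
```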
Prepare tfio datasets
Once the data is available in the cluster, the mongodb.MongoDBIODataset class is used to read it back. The class inherits from tf.data.Dataset and thus exposes all of the useful functionality of tf.data.Dataset out of the box.
Training dataset
train_ds = tfio.experimental.mongodb.MongoDBIODataset(
    uri=URI, database=DATABASE, collection=TRAIN_COLLECTION
)
train_ds
Connection successful: mongodb://localhost:27017
WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow/python/data/experimental/ops/counter.py:66: scan (from tensorflow.python.data.experimental.ops.scan_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.scan(...) instead
WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tensorflow_io/python/experimental/mongodb_dataset_ops.py:114: take_while (from tensorflow.python.data.experimental.ops.take_while_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.take_while(...)
<MongoDBIODataset shapes: (), types: tf.string>
Each item in train_ds is a string that needs to be decoded from JSON. When decoding, you can select only a subset of the columns by specifying a TensorSpec for each of them.
# Numeric features.
numerical_cols = ['PhotoAmt', 'Fee']
SPECS = {
"target": tf.TensorSpec(tf.TensorShape([]), tf.int64, name="target"),
}
for col in numerical_cols:
    SPECS[col] = tf.TensorSpec(tf.TensorShape([]), tf.int32, name=col)
pprint(SPECS)
{'Fee': TensorSpec(shape=(), dtype=tf.int32, name='Fee'),
 'PhotoAmt': TensorSpec(shape=(), dtype=tf.int32, name='PhotoAmt'),
 'target': TensorSpec(shape=(), dtype=tf.int64, name='target')}
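Conceptually, decoding against a set of specs parses each JSON string and keeps only the named fields, cast to the requested types. The following is a plain-Python analogy of that idea, not the tensorflow-io implementation: the helper decode_record and the simplified specs mapping (field name to Python type) are hypothetical stand-ins for decode_json and SPECS.

```python
import json

# Simplified, hypothetical stand-in for SPECS: field name -> Python type.
specs = {"target": int, "PhotoAmt": int, "Fee": int}

def decode_record(raw, specs):
    """Parse one JSON string and keep only the fields named in specs."""
    parsed = json.loads(raw)
    return {name: cast(parsed[name]) for name, cast in specs.items()}

# A hypothetical document with more fields than the model needs.
raw = '{"Type": "Cat", "Age": 3, "PhotoAmt": 1, "Fee": 100, "target": 1}'
print(decode_record(raw, specs))
```

Fields that are not listed in the specs, such as Type and Age here, are simply dropped, which is why only PhotoAmt, Fee and target appear in the decoded dataset.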
BATCH_SIZE=32
train_ds = train_ds.map(
    lambda x: tfio.experimental.serialization.decode_json(x, specs=SPECS)
)
# Prepare a tuple of (features, label)
train_ds = train_ds.map(lambda v: (v, v.pop("target")))
train_ds = train_ds.batch(BATCH_SIZE)
train_ds
<BatchDataset shapes: ({PhotoAmt: (None,), Fee: (None,)}, (None,)), types: ({PhotoAmt: tf.int32, Fee: tf.int32}, tf.int64)>
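The (v, v.pop("target")) idiom works because the first element of the tuple is a reference to the same dictionary that pop then mutates, so the label is removed from the features before the pair is returned. A small plain-Python sketch of that step and of fixed-size batching follows. The rows and the helper names split_label and batch are made up for illustration, and real tf.data batching stacks tensors rather than building lists.

```python
def split_label(row):
    # The tuple holds a reference to `row`; pop() then removes "target" from it.
    return (row, row.pop("target"))

def batch(items, batch_size):
    # Yield consecutive chunks; the last one may be smaller.
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

# Hypothetical decoded rows.
rows = [
    {"PhotoAmt": 1, "Fee": 100, "target": 1},
    {"PhotoAmt": 2, "Fee": 0, "target": 0},
    {"PhotoAmt": 5, "Fee": 50, "target": 1},
]

pairs = [split_label(dict(row)) for row in rows]  # copy so `rows` stays intact
batches = list(batch(pairs, batch_size=2))
print(batches)
```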
Testing dataset
test_ds = tfio.experimental.mongodb.MongoDBIODataset(
    uri=URI, database=DATABASE, collection=TEST_COLLECTION
)
test_ds = test_ds.map(
    lambda x: tfio.experimental.serialization.decode_json(x, specs=SPECS)
)
# Prepare a tuple of (features, label)
test_ds = test_ds.map(lambda v: (v, v.pop("target")))
test_ds = test_ds.batch(BATCH_SIZE)
test_ds
Connection successful: mongodb://localhost:27017
<BatchDataset shapes: ({PhotoAmt: (None,), Fee: (None,)}, (None,)), types: ({PhotoAmt: tf.int32, Fee: tf.int32}, tf.int64)>
Define the keras preprocessing layers
As per the structured data tutorial, it is recommended to use the Keras Preprocessing Layers as they are more intuitive, and can be easily integrated with the models. However, the standard feature_columns can also be used.
For a better understanding of how preprocessing layers are used to classify structured data, please refer to the structured data tutorial.
def get_normalization_layer(name, dataset):
    # Create a Normalization layer for our feature.
    normalizer = preprocessing.Normalization(axis=None)
    # Prepare a Dataset that only yields our feature.
    feature_ds = dataset.map(lambda x, y: x[name])
    # Learn the statistics of the data.
    normalizer.adapt(feature_ds)
    return normalizer
all_inputs = []
encoded_features = []
for header in numerical_cols:
    numeric_col = tf.keras.Input(shape=(1,), name=header)
    normalization_layer = get_normalization_layer(header, train_ds)
    encoded_numeric_col = normalization_layer(numeric_col)
    all_inputs.append(numeric_col)
    encoded_features.append(encoded_numeric_col)
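With axis=None, the Normalization layer learns a single mean and variance over all values it sees during adapt and then standardizes its inputs with them. The sketch below reproduces that computation with the standard library on hypothetical Fee values. It ignores the small epsilon that Keras uses to guard against division by zero.

```python
import statistics

# Hypothetical Fee values standing in for the feature dataset.
fees = [0.0, 100.0, 50.0, 0.0, 250.0]

# "adapt": learn one scalar mean and (population) variance.
mean = statistics.fmean(fees)
variance = statistics.pvariance(fees, mu=mean)

# "call": standardize each value with the learned statistics.
normalized = [(fee - mean) / variance ** 0.5 for fee in fees]
print(normalized)
```

After this step the feature has roughly zero mean and unit variance, which keeps PhotoAmt and Fee on comparable scales before they are concatenated.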
Build, compile and train the model
# Set the parameters
OPTIMIZER="adam"
LOSS=tf.keras.losses.BinaryCrossentropy(from_logits=True)
METRICS=['accuracy']
EPOCHS=10
# Convert the feature columns into a tf.keras layer
all_features = tf.keras.layers.concatenate(encoded_features)
# design/build the model
x = tf.keras.layers.Dense(32, activation="relu")(all_features)
x = tf.keras.layers.Dropout(0.5)(x)
x = tf.keras.layers.Dense(64, activation="relu")(x)
x = tf.keras.layers.Dropout(0.5)(x)
output = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(all_inputs, output)
# compile the model
model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=METRICS)
# fit the model
model.fit(train_ds, epochs=EPOCHS)
Epoch 1/10
109/109 [==============================] - 1s 2ms/step - loss: 0.6261 - accuracy: 0.4711
Epoch 2/10
109/109 [==============================] - 0s 3ms/step - loss: 0.5939 - accuracy: 0.6967
Epoch 3/10
109/109 [==============================] - 0s 3ms/step - loss: 0.5900 - accuracy: 0.6993
Epoch 4/10
109/109 [==============================] - 0s 3ms/step - loss: 0.5846 - accuracy: 0.7146
Epoch 5/10
109/109 [==============================] - 0s 3ms/step - loss: 0.5824 - accuracy: 0.7178
Epoch 6/10
109/109 [==============================] - 0s 2ms/step - loss: 0.5778 - accuracy: 0.7233
Epoch 7/10
109/109 [==============================] - 0s 3ms/step - loss: 0.5810 - accuracy: 0.7083
Epoch 8/10
109/109 [==============================] - 0s 3ms/step - loss: 0.5791 - accuracy: 0.7149
Epoch 9/10
109/109 [==============================] - 0s 3ms/step - loss: 0.5742 - accuracy: 0.7207
Epoch 10/10
109/109 [==============================] - 0s 2ms/step - loss: 0.5797 - accuracy: 0.7083
<keras.callbacks.History at 0x7f743229fe90>
Infer on the test data
res = model.evaluate(test_ds)
print("test loss, test acc:", res)
109/109 [==============================] - 0s 2ms/step - loss: 0.5696 - accuracy: 0.7383
test loss, test acc: [0.569588840007782, 0.7383015751838684]
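Because the final Dense(1) layer has no activation and the loss is built with from_logits=True, the model outputs raw logits rather than probabilities. A sigmoid maps them into the 0 to 1 range. The sketch below shows that conversion and the per-example binary cross-entropy on hypothetical logits, using only the standard library. The 0.5 threshold is an illustrative choice.

```python
import math

def sigmoid(logit):
    return 1.0 / (1.0 + math.exp(-logit))

def binary_cross_entropy(label, logit):
    # Cross-entropy computed from the probability implied by the logit.
    p = sigmoid(logit)
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))

# Hypothetical logits, as if returned by model.predict(...), with true labels.
logits = [2.0, -1.0, 0.0]
labels = [1, 0, 1]

probabilities = [sigmoid(z) for z in logits]
predictions = [int(p >= 0.5) for p in probabilities]
losses = [binary_cross_entropy(y, z) for y, z in zip(labels, logits)]
print(probabilities, predictions, losses)
```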