In this tutorial, we will show you how to get the pooled BERT embeddings of an input text. The required steps are:
- Install the tensorflow_text library
- Load the BERT model from TensorFlow Hub
- Tokenize the input text by converting it to token IDs using a preprocessing model
- Get the pooled embedding using the loaded model
Let’s start coding.
```
pip install --quiet "tensorflow-text==2.8.*"
```

```
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text  # Imports TF ops for preprocessing.
```
Then, we have to configure the model.
```
# Define the model
BERT_MODEL = "https://tfhub.dev/google/experts/bert/wiki_books/2"

# Choose the preprocessing that must match the model
PREPROCESS_MODEL = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
```
Load the BERT model and the preprocessing model.
```
preprocess = hub.load(PREPROCESS_MODEL)
bert = hub.load(BERT_MODEL)
```
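If you are curious about what the preprocessing model produces, you can call it on a list of strings and inspect the returned dictionary. The snippet below is only an illustrative check; the key names (input_word_ids, input_mask, input_type_ids) and the default sequence length of 128 are what the bert_en_uncased_preprocess model is documented to return, and the sample sentence is ours.

```
# Illustrative check of the preprocessing output (not required for the tutorial).
sample = preprocess(["This is a sample sentence"])
print(list(sample.keys()))             # expected: input_word_ids, input_mask, input_type_ids
print(sample['input_word_ids'].shape)  # (1, 128): token IDs padded/truncated to the default length
```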
Build a function that takes a raw text as input and returns the pooled BERT embeddings.
```
def text_to_emb(input_text):
    # Wrap the raw string in a list because the preprocessing model expects a batch.
    input_text_lst = [input_text]
    # Tokenize and convert the text to the input IDs and masks BERT expects.
    inputs = preprocess(input_text_lst)
    # Run BERT and keep only the pooled sentence-level embedding.
    outputs = bert(inputs)
    return np.array(outputs['pooled_output'])
```
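Note that pooled_output is a single fixed-size vector per input sentence. The loaded BERT model also returns a sequence_output entry with one 768-dimensional vector per token, which you could use instead if you need token-level embeddings. Below is a minimal variant of the function above; the name text_to_token_embs is just for illustration.

```
def text_to_token_embs(input_text):
    # Same preprocessing as above, but return per-token vectors instead of the pooled one.
    inputs = preprocess([input_text])
    outputs = bert(inputs)
    # Shape: (1, seq_len, 768), one 768-dimensional vector per token position.
    return np.array(outputs['sequence_output'])
```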
We are done; let’s run an example.
```
text_to_emb('This is a sample sentence')
```
Output:
```
array([[ 0.6729589 , 0.16986154, 0.05168653, 0.4964849 , 0.57049793,
0.3032315 , -0.07513454, -0.99058324, 0.9418068 , 0.35635224,
-0.90153414, -0.9523835 , -0.23299307, 0.38784108, -0.39498374,
-0.21850415, 0.23773886, 0.9872731 , -0.10500205, 0.7786572 ,
-0.29292807, 0.7102717 , 0.34709233, -0.5729018 , 0.9419572 ,
0.37320367, -0.5142145 , 0.85438424, 0.20007657, -0.3199613 ,
0.19322802, -0.261226 , -0.2806239 , 0.1782754 , -0.01885223,
0.53534544, -0.60386556, 0.74343383, -0.90285224, 0.05747601,
0.4736311 , 0.22683787, 0.22758687, -0.18882924, 0.74393773,
-0.16072457, 0.27666202, 0.6570725 , 0.6005741 , -0.38493893,
...
0.04461253, 0.48017693, 0.7973305 , 0.7892748 , -0.58761233,
0.5698995 , -0.7856279 , 0.54465777, 0.3129528 , 0.1456411 ,
0.9632164 , 0.60263646, -0.43343565, 0.82129925, 0.3692738 ,
0.2609335 , -0.4036055 , -0.06602973, 0.9456645 , 0.64760864,
-0.81159633, 0.16129227, -0.8264965 , -0.5995479 , -0.0258658 ,
0.27349144, -0.8450804 , -0.39394245]], dtype=float32)
```
Keep in mind that the shape of the pooled embeddings is (1, 768): one 768-dimensional vector for the single input sentence.
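Since each sentence is now a single 768-dimensional vector, you can compare sentences directly, for example with cosine similarity. The helper below is a quick sketch using only NumPy; the cosine_similarity function and the example sentences are ours, not part of the model.

```
def cosine_similarity(a, b):
    # Cosine similarity between two 1-D vectors.
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

emb1 = text_to_emb('This is a sample sentence')[0]
emb2 = text_to_emb('Here is another example')[0]
print(cosine_similarity(emb1, emb2))
```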