In this tutorial, we will show you how to build a custom text classification model with AWS Comprehend. For this project, we will work with the SMS Spam Collection Dataset from the UCI Machine Learning Repository. Notice that we have replaced the HAM class with 0 and the SPAM class with 1, but this is not necessary.
Train the Model
To train a custom classification model with Comprehend, sign in to the AWS Console, open the Comprehend service, and choose “Custom classification”.
Then click “Create new model” and give your model a name.
For “Data specifications”, we choose the “Classifier mode” and the “CSV file” format. Note that your data should not contain a header row. Finally, we add the S3 location of our data and, for the test dataset, we choose “Autosplit”.
For the output data, we choose a new location by adding a folder to our bucket.
Then choose an IAM role that gives Comprehend access to your S3 bucket, and click the “Create” button at the bottom right.
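If you start from the raw UCI `SMSSpamCollection` file, where each line is a label, a tab, and then the message, a minimal sketch like the one below could produce a header-less two-column CSV (label first, message second) while mapping ham to 0 and spam to 1. The `to_comprehend_csv` helper and the file names are hypothetical; adjust them to your own copy of the dataset.

```python
import csv
import io
from typing import Iterable

# Assumed label encoding (ham -> 0, spam -> 1), written as strings.
LABELS = {"ham": "0", "spam": "1"}


def to_comprehend_csv(lines: Iterable[str]) -> str:
    """Turn 'label<TAB>message' lines into a header-less 'label,text' CSV string.

    Blank lines are skipped; an unknown label raises KeyError.
    """
    buf = io.StringIO()
    writer = csv.writer(buf)  # quotes messages that contain commas or quotes
    for line in lines:
        line = line.rstrip("\r\n")
        if not line.strip():
            continue
        label, text = line.split("\t", 1)
        # No header row: label in the first column, message in the second.
        writer.writerow([LABELS[label], text])
    return buf.getvalue()


if __name__ == "__main__":
    # Hypothetical file names: raw UCI file in, training CSV out.
    # Output is UTF-8; adjust the input encoding if your copy of the file differs.
    with open("SMSSpamCollection", encoding="utf-8") as src:
        data = to_comprehend_csv(src)
    with open("sms_train.csv", "w", encoding="utf-8", newline="") as dst:
        dst.write(data)
```

The resulting file is what you would upload to the S3 bucket referenced in “Data specifications”.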
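The same training job can also be started from code with Boto3's `create_document_classifier`. The sketch below is a minimal, non-authoritative version: it assumes that the console's “Autosplit” corresponds to not passing a separate `TestS3Uri`, and the role ARN, S3 paths, the `sandbox` profile, and the `build_classifier_request`/`start_training` helpers are placeholders or hypothetical names.

```python
from typing import Any, Dict


def build_classifier_request(name: str, role_arn: str, train_uri: str,
                             output_uri: str) -> Dict[str, Any]:
    """Keyword arguments for comprehend.create_document_classifier.

    No TestS3Uri is given, so Comprehend holds out part of the training data
    for testing (assumed to match the console's "Autosplit" option).
    """
    return {
        "DocumentClassifierName": name,
        "DataAccessRoleArn": role_arn,
        "InputDataConfig": {"DataFormat": "COMPREHEND_CSV", "S3Uri": train_uri},
        "OutputDataConfig": {"S3Uri": output_uri},
        "LanguageCode": "en",
        "Mode": "MULTI_CLASS",  # exactly one label per document
    }


def start_training(request: Dict[str, Any], profile: str = "sandbox") -> str:
    """Submit the training job and return the new classifier's ARN."""
    import boto3  # imported here so the request builder works without boto3
    client = boto3.session.Session(profile_name=profile).client("comprehend")
    return client.create_document_classifier(**request)["DocumentClassifierArn"]


if __name__ == "__main__":
    req = build_classifier_request(
        name="sms-ham-spam",
        role_arn="arn:aws:iam::account-id:role/your-comprehend-role",  # placeholder
        train_uri="s3://your-bucket/sms_train.csv",  # placeholder
        output_uri="s3://your-bucket/output/",  # placeholder
    )
    print(start_training(req))
```

Keeping the request builder separate from the API call makes it easy to inspect the parameters before submitting a job that you will be billed for.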
The Trained Model
Training took around 40 minutes, although this depends on the size of your data. Once training finishes, the model appears under “Classifier models”.
Clicking on the “sms-ham-spam” model shows its performance metrics.
As we can see, the model performs very well, with high accuracy, precision, recall, and F1 score.
Finally, in the S3 output location that we set for the model, we can find the confusion matrix computed on the test dataset.
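Instead of refreshing the console during training, you could poll the job status with Boto3's `describe_document_classifier`. This is a sketch under assumptions: the set of terminal statuses below, the 60-second interval, the placeholder ARN, and the helper names are our own choices, not part of the original tutorial.

```python
import time
from typing import Any, Callable

# Statuses from which a training job is assumed not to progress any further.
TERMINAL_STATUSES = {"TRAINED", "TRAINED_WITH_WARNING", "IN_ERROR", "STOPPED"}


def classifier_status(client: Any, classifier_arn: str) -> str:
    """Read the current training status of a custom classifier."""
    response = client.describe_document_classifier(DocumentClassifierArn=classifier_arn)
    return response["DocumentClassifierProperties"]["Status"]


def wait_for_status(get_status: Callable[[], str], poll_seconds: float = 60.0,
                    sleep: Callable[[float], None] = time.sleep) -> str:
    """Call get_status() until it returns a terminal status, sleeping between polls."""
    while True:
        status = get_status()
        if status in TERMINAL_STATUSES:
            return status
        sleep(poll_seconds)


if __name__ == "__main__":
    import boto3
    client = boto3.session.Session(profile_name="sandbox").client("comprehend")
    # Placeholder ARN: copy the real one from "Classifier models".
    arn = "arn:aws:comprehend:region:account-id:document-classifier/sms-ham-spam"
    print(wait_for_status(lambda: classifier_status(client, arn)))
```

Passing `get_status` and `sleep` as arguments keeps the polling loop independent of AWS, so it can be exercised without a real training job.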
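To see how the reported metrics relate to the confusion matrix, here is a small sketch that computes accuracy, and precision, recall, and F1 for one class. It assumes the matrix is a list of lists where rows are the true classes and columns are the predicted classes, in label order 0, 1; we do not reproduce the exact layout of the file Comprehend writes, so check it before reusing this. The `binary_metrics` name and the counts in the example are hypothetical.

```python
from typing import Dict, List


def binary_metrics(matrix: List[List[int]], positive: int = 1) -> Dict[str, float]:
    """Accuracy, plus precision/recall/F1 for the `positive` class.

    Assumes matrix[i][j] counts documents whose true class is i and whose
    predicted class is j.
    """
    n = len(matrix)
    total = sum(sum(row) for row in matrix)
    correct = sum(matrix[i][i] for i in range(n))
    tp = matrix[positive][positive]
    fp = sum(matrix[i][positive] for i in range(n)) - tp  # predicted positive, wrong
    fn = sum(matrix[positive]) - tp  # truly positive, missed
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": correct / total, "precision": precision,
            "recall": recall, "f1": f1}


# Hypothetical counts, not the model's real results:
print(binary_metrics([[5, 1], [1, 3]]))
# {'accuracy': 0.8, 'precision': 0.75, 'recall': 0.75, 'f1': 0.75}
```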
Real-Time Analysis
For real-time analysis, we need to create an endpoint. So we go to the “Endpoints” tab and choose “Create endpoint”.
Then we define the “Endpoint name” and set the number of inference units (IUs) to 1. Once you complete these steps, the endpoint is created and listed under “Endpoints”.
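The endpoint can also be created with Boto3's `create_endpoint`. A minimal sketch, assuming the model ARN is copied from the “Classifier models” page; the ARN shown is a placeholder and the helper names are hypothetical.

```python
from typing import Any, Dict


def build_endpoint_request(name: str, model_arn: str,
                           inference_units: int = 1) -> Dict[str, Any]:
    """Keyword arguments for comprehend.create_endpoint."""
    if inference_units < 1:
        raise ValueError("an endpoint needs at least one inference unit")
    return {"EndpointName": name, "ModelArn": model_arn,
            "DesiredInferenceUnits": inference_units}


def create_endpoint(request: Dict[str, Any], profile: str = "sandbox") -> str:
    """Create the endpoint and return its ARN (needed later for classify_document)."""
    import boto3  # deferred so build_endpoint_request is usable without boto3
    client = boto3.session.Session(profile_name=profile).client("comprehend")
    return client.create_endpoint(**request)["EndpointArn"]


if __name__ == "__main__":
    req = build_endpoint_request(
        "sms-ham-spam-endpoint",
        "arn:aws:comprehend:region:account-id:document-classifier/sms-ham-spam",  # placeholder
    )
    print(create_endpoint(req))
```

One inference unit matches the console setting used above; billing depends on the number of units, so raise it only if you need more throughput.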
Now we are ready for real-time analysis. Go to the “Real-time analysis” section, choose “Custom”, and select the endpoint. Finally, in the input text box, enter the text you want to classify.
As we can see, the example we entered is classified as “SPAM” with 99% confidence.
Real-Time Analysis with SDK for Python
We can use the Python SDK or the AWS CLI to call the custom model. Let’s see how to do it with Boto3. First, copy the ARN of the endpoint. We will try the following input:
Lol your always so convincing.
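import boto3
from pprint import pprint
# ARN of the endpoint (replace region and account-id with your own values)
endpoint = 'arn:aws:comprehend:region:account-id:document-classifier-endpoint/sms-ham-spam-endpoint'
# 'sandbox' is a named profile from the local AWS credentials file
session = boto3.session.Session(profile_name='sandbox')
client = session.client('comprehend')
mytxt = "Lol your always so convincing."
response = client.classify_document(Text=mytxt, EndpointArn=endpoint)
pprint(response)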
Output
{'Classes': [{'Name': '0', 'Score': 0.9999861717224121},
{'Name': '1', 'Score': 1.3849913557351101e-05}],
'ResponseMetadata': {'RequestId': 'f8c7d94a-a23c-4fc0-9cf5-7229f1387fb6',
'HTTPStatusCode': 200,
'HTTPHeaders': {'x-amzn-requestid': 'f8c7d94a-a23c-4fc0-9cf5-7229f1387fb6',
'content-type': 'application/x-amz-json-1.1',
'content-length': '96',
'date': 'Wed, 06 Apr 2022 13:14:41 GMT'},
'RetryAttempts': 0}}
If you just want to extract the score of the first class:
response['Classes'][0]['Score']
0.9999861717224121
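Indexing `[0]` assumes the first entry is the predicted class. A slightly more robust sketch picks the entry with the highest score and maps the numeric label back to a readable name. The `top_class` helper and the `LABEL_NAMES` mapping are hypothetical and assume the 0 = ham, 1 = spam encoding used when preparing the data.

```python
from typing import Any, Dict, Tuple

# Assumed mapping, matching the encoding used when preparing the data.
LABEL_NAMES = {"0": "ham", "1": "spam"}


def top_class(response: Dict[str, Any]) -> Tuple[str, float]:
    """Return (label name, score) of the highest-scoring class in a classify_document response."""
    best = max(response["Classes"], key=lambda c: c["Score"])
    return LABEL_NAMES.get(best["Name"], best["Name"]), best["Score"]


# Using the class scores from the output shown above:
example = {"Classes": [{"Name": "0", "Score": 0.9999861717224121},
                       {"Name": "1", "Score": 1.3849913557351101e-05}]}
print(top_class(example))  # ('ham', 0.9999861717224121)
```

Unknown label names fall back to the raw name returned by Comprehend, so the helper still works if you retrain with different labels.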
To run the same real-time analysis with the AWS CLI, use:
aws comprehend classify-document \
    --endpoint-arn arn:aws:comprehend:region:account-id:document-classifier-endpoint/sms-ham-spam-endpoint \
    --text 'Lol your always so convincing.'
The above example is formatted for Unix, Linux, and macOS. For Windows, replace the backslash (\) Unix continuation character at the end of each line with a caret (^).
Don’t Forget to Delete the Endpoint
Note that you are charged for as long as the endpoint is running, so once you are done, do not forget to delete it.
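Deletion can also be scripted. The sketch below assumes Boto3's `list_endpoints` and `delete_endpoint` calls, follows `NextToken` manually to page through results, and uses a placeholder ARN and hypothetical helper names; deletion is requested per endpoint and completes asynchronously.

```python
from typing import Any, Dict, List


def endpoint_arns(client: Any) -> List[str]:
    """Collect the ARNs of all Comprehend endpoints visible to the client."""
    arns: List[str] = []
    kwargs: Dict[str, str] = {}
    while True:
        page = client.list_endpoints(**kwargs)
        arns += [e["EndpointArn"] for e in page.get("EndpointPropertiesList", [])]
        token = page.get("NextToken")
        if not token:
            return arns
        kwargs = {"NextToken": token}


def delete_endpoints(client: Any, arns: List[str]) -> None:
    """Request deletion of each endpoint (the deletion itself is asynchronous)."""
    for arn in arns:
        client.delete_endpoint(EndpointArn=arn)


if __name__ == "__main__":
    import boto3
    client = boto3.session.Session(profile_name="sandbox").client("comprehend")
    # Delete only the tutorial's endpoint (placeholder ARN).
    delete_endpoints(client, ["arn:aws:comprehend:region:account-id:document-classifier-endpoint/sms-ham-spam-endpoint"])
    print(endpoint_arns(client))  # endpoints still listed (may include ones being deleted)
```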