In this tutorial, we will show you how to build a public REST API that calls a SageMaker endpoint. The tutorial consists of the following parts:
- Create the SageMaker Model
- Deploy the SageMaker Endpoint
- Create a Lambda Function
- Create an API Gateway
- Call the REST API via Postman
The SageMaker Model and Endpoint
First, log in to the AWS console, go to SageMaker, and create a new notebook instance. Once you open Jupyter, go to the SageMaker Examples and choose “Breast Cancer Prediction.ipynb”. Comment out the final cell, which deletes the endpoint, and then run the notebook. When you are done with the tutorial, you can delete the endpoint yourself.
If you go to Inference → Endpoints, you will see the name of the endpoint that you created.
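If you would rather confirm from code that the endpoint is ready before wiring up Lambda, you could poll its status with boto3's `describe_endpoint`. The following is a minimal sketch that assumes a `boto3.client("sagemaker")` is passed in. The helper name `wait_for_endpoint` is hypothetical. boto3 also ships a built-in `endpoint_in_service` waiter that does the same job.

```python
import time


def wait_for_endpoint(client, endpoint_name, poll_seconds=30, max_polls=60):
    """Poll DescribeEndpoint until the endpoint reports InService.

    `client` is assumed to be boto3.client("sagemaker"); it is injected so
    the function can be exercised with a stub.
    """
    status = None
    for _ in range(max_polls):
        status = client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            raise RuntimeError(f"Endpoint {endpoint_name} failed to deploy")
        # Still Creating/Updating: wait and check again
        time.sleep(poll_seconds)
    raise TimeoutError(f"Endpoint {endpoint_name} is still {status}")
```

Usage would look like `wait_for_endpoint(boto3.client("sagemaker"), "your-endpoint-name")`, where the endpoint name is a placeholder for the one shown in the console.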
Lambda Function
Once the SageMaker endpoint is running, you can move on to the Lambda function. Go to “Services”, search for “Lambda”, and click the orange “Create function” button.
Then you can set the function name that you want and for the runtime you can select Python 3.7.
Finally, and most importantly, add a policy to the function's new execution role that allows it to invoke the endpoint. The policy is:
```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Sid": "VisualEditor0",
      "Effect": "Allow",
      "Action": "sagemaker:InvokeEndpoint",
      "Resource": "*"
    }
  ]
}
```
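With `"Resource": "*"`, the role can invoke any endpoint in the account. The following is a minimal sketch of a tighter variant that scopes the permission to one endpoint ARN and attaches it inline with boto3's `put_role_policy`. It assumes you copy the endpoint ARN from the endpoint's details page, and that a `boto3.client("iam")` is passed in. The helper names and the default policy name `InvokeSageMakerEndpoint` are hypothetical.

```python
import json


def invoke_endpoint_policy(endpoint_arn):
    """Build a policy document allowing InvokeEndpoint on a single endpoint."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": "sagemaker:InvokeEndpoint",
                "Resource": endpoint_arn,
            }
        ],
    }


def attach_invoke_policy(iam_client, role_name, endpoint_arn,
                         policy_name="InvokeSageMakerEndpoint"):
    """Attach the policy inline to the Lambda execution role.

    `iam_client` is assumed to be boto3.client("iam").
    """
    iam_client.put_role_policy(
        RoleName=role_name,
        PolicyName=policy_name,
        PolicyDocument=json.dumps(invoke_endpoint_policy(endpoint_arn)),
    )
```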
Now, we are ready to write the lambda function.
```python
import json
import os

import boto3

# Grab environment variables
ENDPOINT_NAME = os.environ['ENDPOINT_NAME']
runtime = boto3.client('sagemaker-runtime')


def lambda_handler(event, context):
    # With the non-proxy API Gateway integration, the JSON request body is the event itself
    payload = event['data']
    response = runtime.invoke_endpoint(EndpointName=ENDPOINT_NAME,
                                       ContentType='text/csv',
                                       Body=payload)
    result = json.loads(response['Body'].read().decode())
    # Linear Learner returns a probability in 'score' and a 0/1 'predicted_label';
    # int(score) would truncate almost every probability to 0
    pred = int(result['predictions'][0]['predicted_label'])
    predicted_label = 'M' if pred == 1 else 'B'
    return predicted_label
```
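The response parsing in the handler can be pulled out into a pure function and tested without AWS. The following is a minimal sketch. It assumes the endpoint is the Linear Learner binary classifier trained in the notebook, whose JSON response holds a probability `score` and a 0/1 `predicted_label` for each record. The helper name `parse_prediction` is hypothetical.

```python
import json


def parse_prediction(body_bytes):
    """Map a Linear Learner binary-classifier response to 'M' or 'B'.

    Assumes a response shaped like
    {"predictions": [{"score": 0.97, "predicted_label": 1}]}.
    """
    result = json.loads(body_bytes.decode("utf-8"))
    first = result["predictions"][0]
    # predicted_label is already 0 or 1; score is a probability in [0, 1]
    label = int(first["predicted_label"])
    return "M" if label == 1 else "B"
```

Reading `predicted_label` avoids calling `int()` on the probability. That call would truncate any score below 1.0 to 0 and label nearly every sample “B”.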
Note that ENDPOINT_NAME is an environment variable. Here, it holds the name of the SageMaker endpoint from the previous step. To add the environment variable, go to the “Configuration” tab and choose “Environment variables”.
Choose “Add environment variable”, enter “ENDPOINT_NAME” as the key, and enter the name of the SageMaker endpoint as the value.
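The same environment variable can also be set from code. The following is a minimal sketch using boto3's `get_function_configuration` and `update_function_configuration`. It assumes a `boto3.client("lambda")` is passed in, and the helper name `set_endpoint_name` is hypothetical. Because `Environment` replaces the function's whole variable map, the sketch merges with the existing variables first.

```python
def set_endpoint_name(lambda_client, function_name, endpoint_name):
    """Set ENDPOINT_NAME on a Lambda function, keeping other variables.

    `lambda_client` is assumed to be boto3.client("lambda").
    """
    config = lambda_client.get_function_configuration(FunctionName=function_name)
    # Copy the existing variables so they are not wiped by the update
    variables = dict(config.get("Environment", {}).get("Variables", {}))
    variables["ENDPOINT_NAME"] = endpoint_name
    lambda_client.update_function_configuration(
        FunctionName=function_name,
        Environment={"Variables": variables},
    )
    return variables
```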
Create the API Gateway
Once the Lambda function is ready, you can build the API. Go to “Services” and search for “API Gateway”. Choose “REST API” and click “Build”, then select “New API”, give it a name, and click “Create API”.
Next, go to “Actions”, choose “Create resource”, and enter a name for the resource. Then create a POST method (Actions → Create Method → POST). For the integration type, select “Lambda Function” and enter the name of the Lambda function you created before. Finally, go back to the “Actions” menu, select “Deploy API”, and create a new stage with a name of your choice.
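With this non-proxy “Lambda Function” integration, the parsed JSON request body arrives as the event itself, which is why the handler can read `event['data']` directly. If “Use Lambda Proxy integration” were ticked instead, the body would arrive as a string under `event['body']`, base64-encoded when `isBase64Encoded` is true. The following is a minimal sketch of a hypothetical helper, `extract_payload`, that accepts either shape. Note that with proxy integration the function would also have to return a response object with `statusCode` and `body`, which this sketch does not cover.

```python
import base64
import json


def extract_payload(event):
    """Return the CSV string under "data" for either integration type."""
    if isinstance(event.get("body"), str):
        # Lambda proxy integration: the raw request body is a string
        body = event["body"]
        if event.get("isBase64Encoded"):
            body = base64.b64decode(body).decode("utf-8")
        return json.loads(body)["data"]
    # Non-proxy integration: the JSON body is the event itself
    return event["data"]
```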
If you go back to the Lambda function, you will see that API Gateway has been added as a trigger. From there, you can find the invoke URL that we will use with Postman.
Note that we neither restricted the API to a VPC nor added any credentials. In other words, the API is public, and anyone who has the URL can call it. Keep this in mind before using this serverless API in a real application.
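The invoke URL of a deployed REST API stage follows the pattern `https://{rest-api-id}.execute-api.{region}.amazonaws.com/{stage}/{resource}`. The following is a small sketch that assembles it from its parts. The function name `invoke_url` is hypothetical, and the API ID, region, stage, and resource are whatever you chose when creating and deploying the API.

```python
def invoke_url(rest_api_id, region, stage, resource_path):
    """Assemble the public invoke URL of a deployed REST API stage."""
    path = resource_path.lstrip("/")
    return f"https://{rest_api_id}.execute-api.{region}.amazonaws.com/{stage}/{path}"
```

For example, a stage named `test` with a resource named `predictbreastcancer` yields a URL ending in `/test/predictbreastcancer`.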
Call the REST API
We can call the REST API from Postman by sending a POST request to the invoke URL with the following JSON body:
{"data":"13.49,22.3,86.91,561.0,0.08752,0.07697999999999999,0.047510000000000004,0.033839999999999995,0.1809,0.057179999999999995,0.2338,1.3530000000000002,1.735,20.2,0.004455,0.013819999999999999,0.02095,0.01184,0.01641,0.001956,15.15,31.82,99.0,698.8,0.1162,0.1711,0.2282,0.1282,0.2871,0.06917000000000001"}
With the example above, we got back “B” as the response, where “B” stands for “benign”.
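Rather than pasting a long string by hand, you can build the request body from a list of feature values. The following is a minimal sketch that assumes 30 features per row, as in the example above. The names `build_request_body` and `N_FEATURES` are hypothetical.

```python
import json

# The example row above contains 30 comma-separated feature values
N_FEATURES = 30


def build_request_body(features):
    """Serialize one row of features into the {"data": "<csv>"} body."""
    if len(features) != N_FEATURES:
        raise ValueError(f"expected {N_FEATURES} features, got {len(features)}")
    csv_row = ",".join(str(float(x)) for x in features)
    return json.dumps({"data": csv_row})
```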
We can also call the API with Python as follows:
```python
import json

import requests

url = 'https://xxxxxxx.execute-api.xx-xxxx-x.amazonaws.com/test/predictbreastcancer'

input_data = "13.49,22.3,86.91,561.0,0.08752,0.07697999999999999,0.047510000000000004,0.033839999999999995,0.1809,0.057179999999999995,0.2338,1.3530000000000002,1.735,20.2,0.004455,0.013819999999999999,0.02095,0.01184,0.01641,0.001956,15.15,31.82,99.0,698.8,0.1162,0.1711,0.2282,0.1282,0.2871,0.06917000000000001"
# Wrap the CSV row in the JSON body the Lambda function expects
input_data = json.dumps({'data': input_data})

r = requests.post(url, data=input_data)
r.json()  # 'B'
```
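If `requests` is not available, the same call can be made with the standard library alone. The following is a minimal sketch using `urllib.request`. It sets the `Content-Type` header explicitly and uses a timeout, and `urlopen` raises `HTTPError` for non-2xx responses. The helper names `build_request` and `predict` are hypothetical, and the URL is your own invoke URL.

```python
import json
import urllib.request


def build_request(url, csv_row):
    """Prepare a POST request carrying {"data": csv_row} as JSON."""
    body = json.dumps({"data": csv_row}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict(url, csv_row, timeout=10):
    """Send the request and decode the JSON response (e.g. "B")."""
    with urllib.request.urlopen(build_request(url, csv_row), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```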