Moving from Lab to Production: Deploying Prompt Classification Models with TorchServe
In our previous blogpost, we put prompt classification models to the test in an academic setting, comparing their capabilities and how they stack up against each other. Now it's time to take the next step: transitioning these models from research to reality.
Transitioning a model to production means preparing it for real-world use. This involves optimizing the model, ensuring it can handle high volumes of requests, and setting up a system that allows it to function seamlessly within an application.
To help us with this, we'll use TorchServe, a PyTorch model serving tool specifically designed to take the complexities out of deploying models. With TorchServe, we can focus on refining our models while it takes care of the deployment process.
In this blogpost, we'll walk through deploying one of the best-performing models from our previous study: a pretrained Sentence-BERT (SBERT) encoder combined with a Logistic Regression classifier.
Prerequisites
Before we start, make sure you have the following installed:
- Python
- PyTorch
- TorchServe
- sentence_transformers
- joblib
- scikit-learn
We assume you already have a trained logistic regression model (classifier.joblib) that takes SBERT embeddings as input and performs classification.
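If you still need to create such a classifier, the sketch below shows one way it could be produced: embed the training prompts with SBERT, fit a scikit-learn Logistic Regression on the embeddings, and dump it with joblib. The prompts and labels here are placeholders, not the dataset from our previous post.
# Minimal training sketch: embed prompts with SBERT, fit a Logistic Regression, save it with joblib.
import joblib
from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression

train_texts = ["How to draw a portrait?", "How do I cook pasta?"]  # placeholder prompts
train_labels = ["Arts & Crafts", "Cooking"]                        # placeholder categories

sbert = SentenceTransformer("paraphrase-distilroberta-base-v1")
embeddings = sbert.encode(train_texts)  # numpy array of shape (n_samples, embedding_dim)

classifier = LogisticRegression(max_iter=1000)
classifier.fit(embeddings, train_labels)
joblib.dump(classifier, "classifier.joblib")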
Building the .mar file
The first step in deploying a model with TorchServe is creating a .mar file, which is an archive file containing everything needed to run your model.
Let's start by preparing our model for TorchServe. For this, we create a directory that will contain our SBERT model and the trained Logistic Regression classifier.
In the following bash script (buildmar_sbert.sh), we do the following:
- Create a directory named after our SBERT model and copy the trained classifier into it.
- Load the pretrained SBERT model (and the classifier, as a quick sanity check) and save the SBERT model into the same directory.
- Build a comma-separated list of all the files in the model directory (the SBERT model files plus the classifier).
- Create a TorchServe Model Archive (.mar file) with the torch-model-archiver tool and start TorchServe with the new model.
#!/bin/bash
# export variables for inline python script
export DIRNAME="sbert-lr-paraphrase-distilroberta-base-v1"
export SBERT_NAME="paraphrase-distilroberta-base-v1"
export CLASSIFIER="classifier.joblib"
export PREFIX="sbert-lr"
# prepare folder to which the models will be saved
mkdir -p $DIRNAME
# save classifier model joblib dump
cp $CLASSIFIER $DIRNAME
# load pretrained SBERT model and save it to the folder
python3 - <<EOF
import os
import joblib
import torch
from sentence_transformers import SentenceTransformer
model_dir = os.environ['DIRNAME']
classifier = os.environ['CLASSIFIER']
sbert_model_name = os.environ['SBERT_NAME']
sbert_model = SentenceTransformer(sbert_model_name)
# sanity check that the classifier dump can be loaded
loaded_classifier = joblib.load(classifier)
sbert_model.save(model_dir)
print(f"Saved model to {model_dir}")
EOF
# names for model and mar files
model=$PREFIX-model
mar=$PREFIX-model.mar
# include all files needed for SBERT and LR classifier
FILES=$(find $DIRNAME)
EXTRA_FILES=$(echo $FILES | sed 's/ /,/g')
echo $EXTRA_FILES
# create folders for model storage and logs
mkdir -p sbert_store
mkdir -p logs
# clear previous logs, archive model to mar file
rm logs/* ; torch-model-archiver --model-name $model --version 1.0 --handler sbert_handler.py --extra-files $EXTRA_FILES,sbert_handler.py -f && cp $model.mar sbert_store/$mar
# to run on cpu on a gpu machine:
#echo number_of_gpus=0>config.properties
#CUDA_VISIBLE_DEVICES="" torchserve --start --ncs --model-store sbert_store --models model=$mar --ts-config config.properties
# to run with default settings (gpu if available, otherwise cpu)
torchserve --start --ncs --model-store sbert_store --models sbert-lg-categorizer=$mar
# to stop torchserve
# torchserve --stop
Note that in this script, the torch-model-archiver command does not use the --serialized-file option. This is because our model consists of multiple components (SBERT and Logistic Regression) that are loaded separately in the model handler, and because SentenceTransformer can initialize all the necessary components directly from the model directory.
Preparing the model handler
Before we run our script to pack our model into a .mar file, we need to create a handler to process incoming data and interact with the model. The handler is a Python script that provides specific methods, including initialize, preprocess, inference, and postprocess, that dictate how data is managed within your model.
For our prompt classification task, our handler will accomplish the following:
- In initialize, it sets up the device for inference (CPU/GPU) and loads the Sentence-BERT model and the Logistic Regression classifier from the model directory.
- preprocess transforms the raw input data into a format the model can understand. In our case, it extracts the text prompts from the input data.
- inference generates embeddings from the preprocessed data using Sentence-BERT, and then conducts inference using the Logistic Regression model.
- postprocess transforms the model's output into a format that can be returned as a response to the client. Here, it converts the predicted class indices into a list.
Let's take a look at the Python script for the handler (sbert_handler.py):
import joblib
from sentence_transformers import SentenceTransformer
import torch
from torch import Tensor
from ts.torch_handler.base_handler import BaseHandler
import logging
import os
logger = logging.getLogger(__name__)
# torch_xla is optional; if it is not installed we simply fall back to CPU/GPU
try:
    import torch_xla.core.xla_model as xm
    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False


class ModelHandler(BaseHandler):
    def __init__(self):
        super().__init__()

    def initialize(self, context):
        # Select the inference device: GPU if TorchServe assigned one, XLA if available, otherwise CPU.
        properties = context.system_properties
        if torch.cuda.is_available() and properties.get("gpu_id") is not None:
            self.map_location = "cuda"
            self.device = torch.device(
                self.map_location + ":" + str(properties.get("gpu_id"))
            )
        elif XLA_AVAILABLE:
            self.map_location = "xla"
            self.device = xm.xla_device()
        else:
            self.map_location = "cpu"
            self.device = torch.device(self.map_location)
        self.manifest = context.manifest
        model_dir = properties.get("model_dir")
        # Load the SBERT encoder from the extracted model archive directory.
        self.model = SentenceTransformer(model_dir)
        self.model.to(self.device)
        logger.debug("SBERT Model file %s loaded successfully on %s", model_dir, self.device)
        # Load the Logistic Regression classifier shipped alongside the SBERT files.
        self.classifier = joblib.load(os.path.join(model_dir, "classifier.joblib"))
        logger.debug("Classifier Model file loaded successfully")
        self.initialized = True

    def preprocess(self, data):
        # Extract the raw text prompts from the incoming requests.
        text_batch = []
        for line in data:
            text = line.get("data") or line.get("body")
            # Decode text if not a str but bytes or bytearray
            if isinstance(text, (bytes, bytearray)):
                text = text.decode("utf-8")
            text_batch.append(text)
        return text_batch

    def inference(self, data, *args, **kwargs):
        # Embed the prompts with SBERT, then classify the embeddings with the Logistic Regression model.
        encoded_data = self.model.encode(data, convert_to_tensor=True,
                                         device=self.device)
        results = self.classifier.predict(encoded_data.cpu().detach().numpy())
        return results

    def postprocess(self, data):
        # Convert the array of predicted labels into a plain list for the response.
        return data.tolist()
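Before building the archive, you can smoke-test the handler logic outside of TorchServe by faking the parts of the context it reads. This is only a rough sketch and assumes the model directory created by buildmar_sbert.sh (sbert-lr-paraphrase-distilroberta-base-v1) exists next to sbert_handler.py; in production, TorchServe constructs the context for you.
# Quick local smoke test for sbert_handler.py (not needed for deployment).
from types import SimpleNamespace
from sbert_handler import ModelHandler

# Fake the minimal bits of the TorchServe context that the handler reads.
context = SimpleNamespace(
    system_properties={
        "model_dir": "sbert-lr-paraphrase-distilroberta-base-v1",  # created by buildmar_sbert.sh
        "gpu_id": None,                                            # forces the CPU branch
    },
    manifest={},
)

handler = ModelHandler()
handler.initialize(context)
batch = handler.preprocess([{"body": b"How to draw a portrait?"}])
print(handler.postprocess(handler.inference(batch)))  # e.g. ['Arts & Crafts']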
Starting TorchServe
With the build script and the handler in place, we are now ready to pack our model into a .mar file and start TorchServe.
./buildmar_sbert.sh
The script will serve our model with the command:
torchserve --start --ncs --model-store sbert_store --models sbert-lg-categorizer=$mar
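To confirm that the model has been registered before sending prediction requests, you can query TorchServe's management API, which listens on port 8081 by default. A minimal sketch using the requests library, assuming the default management address:
# List the models currently registered with TorchServe.
import requests

response = requests.get("http://localhost:8081/models")
response.raise_for_status()
print(response.json())  # should include sbert-lg-categorizer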
Once TorchServe is ready, we can test our endpoint:
curl -d '[{"data": "How to draw a portrait?"}]' -H "Content-Type: text/plain" http://localhost:8080/predictions/sbert-lg-categorizer
Arts & Crafts
We can see that our model has successfully predicted the prompt category “Arts & Crafts” for our prompt “How to draw a portrait?”.
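The same prediction can be requested from Python, which is convenient when calling the categorizer from application code. A minimal sketch with the requests library, assuming the server runs locally on the default inference port:
# Send a prompt to the TorchServe inference API and print the predicted category.
import requests

response = requests.post(
    "http://localhost:8080/predictions/sbert-lg-categorizer",
    data="How to draw a portrait?",
    headers={"Content-Type": "text/plain"},
)
response.raise_for_status()
print(response.text)  # e.g. Arts & Crafts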
For serving the SetFit model, we follow a similar approach. Although the specifics of preprocessing, inference, and post-processing vary according to the nature of the model, the general methodology of setting up the handler and using TorchServe remains the same.
Keep in mind that in the example provided, we are working with TorchServe in its default mode, known as eager execution mode. While this is more than capable of serving models in most real-world scenarios, there may be instances where you seek enhanced performance or have specific requirements that are better suited to a more optimized execution strategy.
In such cases, TorchServe also supports serving models that are converted to the TorchScript format, a more performant and portable representation of PyTorch models. By using TorchScript, your model can benefit from ahead-of-time compilation, which can significantly improve execution times by allowing optimizations such as fusion of multiple operations and pre-computation of results that can be calculated at compile time.
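As a generic illustration of that workflow (not specific to the SBERT pipeline above), a PyTorch module can be traced with an example input and saved as TorchScript; the resulting file can then be handed to torch-model-archiver via the --serialized-file option:
# Generic TorchScript example: trace a toy module and save the compiled artifact.
import torch
import torch.nn as nn

class TinyClassifier(nn.Module):
    def __init__(self, in_dim: int = 768, num_classes: int = 5):
        super().__init__()
        self.linear = nn.Linear(in_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

model = TinyClassifier().eval()
example_input = torch.randn(1, 768)
traced = torch.jit.trace(model, example_input)  # ahead-of-time compilation to TorchScript
traced.save("tiny_classifier.pt")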
Deployment on a Kubernetes Cluster
In this section we will:
- create an entrypoint script entry.sh for our pod to start TorchServe
- create a Dockerfile and build the Docker image
- push this image to a repository and run it on Google Kubernetes Engine
In many cases, it is possible to use the default TorchServe Docker image and simply add the .mar model file to the container. This is a viable approach, especially when the model doesn't have any special dependencies or functionality beyond what the TorchServe image provides. In our case, however, we will create our own Dockerfile. We therefore first write the entrypoint script entry.sh, which will start TorchServe when the Docker container is run:
#!/bin/bash
set -e
if [[ "$1" = "serve" ]]; then
    shift 1
    torchserve --start --ncs --model-store sbert_store --models sbert-lg-categorizer=sbert-lr-model.mar
else
    eval "$@"
fi
# prevent docker exit
tail -f /dev/null
This script checks whether the first argument passed to it is serve. If so, it starts TorchServe; otherwise it evaluates the arguments as a command. The tail -f /dev/null command at the end prevents the Docker container from exiting after TorchServe has started.
The Dockerfile could look similar to the example below. Note that you also have to prepare a requirements.txt file to install the required Python packages (a sample is shown after the Dockerfile).
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel
RUN echo "nameserver 8.8.8.8" | tee /etc/resolv.conf > /dev/null
RUN apt-get clean && apt-get update
RUN DEBIAN_FRONTEND="noninteractive" TZ="Europe/London" apt-get -y install tzdata
RUN apt-get -y install software-properties-common \
build-essential libncursesw5-dev libssl-dev \
libsqlite3-dev tk-dev libgdbm-dev libc6-dev libbz2-dev \
libffi-dev zlib1g-dev openjdk-11-jdk libreadline-gplv2-dev
RUN add-apt-repository -y ppa:deadsnakes/ppa && apt-get update
RUN apt-get -y install python3.11 python3.11-distutils
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
WORKDIR /workspace/
RUN mkdir -p /workspace/sbert_store
COPY sbert-lr-model.mar /workspace/sbert_store/sbert-lr-model.mar
COPY requirements.txt /workspace/
COPY config.properties /workspace/
COPY entry.sh /workspace/entry.sh
RUN chmod +x /workspace/entry.sh
RUN pip3 install -r /workspace/requirements.txt
EXPOSE 8080 8081 8082
RUN useradd --uid 10000 runner
RUN chown -R runner:runner /workspace/
USER 10000
ENTRYPOINT ["/workspace/entry.sh"]
CMD ["serve"]
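A possible requirements.txt for this image is shown below; the packages are unpinned here, but in practice you should pin the versions you used to train and archive the model. Note that the config.properties file copied by the Dockerfile also has to exist in the build context; it can hold TorchServe settings such as number_of_gpus, or stay minimal if the defaults are sufficient.
# requirements.txt (pin versions to match your training environment)
torch
torchserve
torch-model-archiver
sentence-transformers
scikit-learn
joblib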
The next step is to build the image and push it to a Google Artifact Registry repository.
export CLUSTER=mycluster
export REGION=myregion
export PROJECT_ID=myproject
export REPO=myrepo
export IMAGE=categorizer-sbert-lg
export VERSION=v1
export IMAGEPATH=${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPO}/${IMAGE}:${VERSION}
docker build -t ${IMAGEPATH} .
gcloud config set project ${PROJECT_ID}
gcloud config set compute/zone ${REGION}
gcloud container clusters get-credentials ${CLUSTER} --zone ${REGION}
docker push ${IMAGEPATH}
Next, we're creating a Kubernetes deployment from this image.
cat >deployment.yaml <<EOL
apiVersion: apps/v1
kind: Deployment
metadata:
  name: ${IMAGE}
spec:
  replicas: 1
  selector:
    matchLabels:
      run: ${IMAGE}
  template:
    metadata:
      labels:
        run: ${IMAGE}
    spec:
      containers:
      - name: ${IMAGE}
        image: ${IMAGEPATH}
        imagePullPolicy: Always
        resources:
          limits:
            nvidia.com/gpu: 1
        ports:
        - containerPort: 8080
EOL
kubectl create -f deployment.yaml
After creating a port-forward to the pod, we can test our deployed TorchServe categorizer.
pod_name=$(kubectl get pods -o jsonpath='{.items[*].metadata.name}' | tr ' ' '\n' | grep ${IMAGE} | head -n 1)
kubectl port-forward pod/${pod_name} 8080:8080 &
curl -d '[{"data": "How to draw a circle?"}]' -H "Content-Type: text/plain" http://127.0.0.1:8080/predictions/sbert-lg-categorizer
# Arts & Crafts