In this lesson, we explore the basics of Flask and demonstrate how to serve a machine learning model using Flask. This guide combines content from a Jupyter Notebook and a standalone Flask application file, providing a comprehensive introduction to model deployment with Flask.
Before building the application, ensure that Flask is installed. Run the following commands to install Flask, check its version, and inspect the directory structure of your Flask app:
```bash
!pip install Flask
```
```bash
!python -m flask --version
```
```bash
!tree flask_app/
```
These commands not only install Flask but also verify that the essential files exist within the flask_app/ directory, an important part of your model deployment workflow.

If Flask is already installed, you may see output indicating that the requirements are already satisfied, for example:
```bash
!pip install Flask
```
```text
Requirement already satisfied: Flask in /root/venv/lib/python3.11/site-packages (3.1.0)
Requirement already satisfied: Werkzeug>=3.1 in /root/venv/lib/python3.11/site-packages (from Flask) (3.1.3)
Requirement already satisfied: Jinja2>=3.1.2 in /root/venv/lib/python3.11/site-packages (from Flask) (3.1.5)
Requirement already satisfied: itsdangerous>=2.2 in /root/venv/lib/python3.11/site-packages (from Flask) (2.2.0)
Requirement already satisfied: click>=8.1.3 in /root/venv/lib/python3.11/site-packages (from Flask) (8.1.8)
Requirement already satisfied: blinker>=1.9 in /root/venv/lib/python3.11/site-packages (from Flask) (1.9.0)
Requirement already satisfied: MarkupSafe>=2.0 in /root/venv/lib/python3.11/site-packages (from Jinja2>=3.1.2->Flask) (3.0.2)
```
You can also re-run the version command from within the Notebook:
```bash
python -m flask --version
```
This command displays your Python version, Flask version (3.1.0), and Werkzeug version (3.1.3).

Next, verify the structure of your Flask app with the tree command shown earlier.
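The exact tree output depends on your project, but given the files referenced later in this lesson (app.py, which you run directly, and the local image_transforms module it imports), a minimal layout might look like the sketch below. Treat it as illustrative, not as the required structure:

```text
flask_app/
├── app.py               # the Flask application shown below
└── image_transforms.py  # preprocessing helpers imported by app.py
```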
Before starting the Flask server, initialize the app by loading any required environment variables and your machine learning model. In this example, we use the MobileNetV3 Large pre-trained model. It is essential that the model is loaded before the application processes any requests.

Below is an example of the initial setup with logging and error handling:
```python
import os
import io
import base64
import json
import logging

from flask import Flask, request, jsonify
from torchvision import models
import torch
from PIL import Image
from image_transforms import preprocess  # Ensure this module is available

# Initialize Flask app
app = Flask(__name__)

# Set up logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load environment variables or secrets
MY_SECRET = os.getenv('SECRET')

# Load the MobileNetV3 Large pre-trained model before starting the app
try:
    logger.info("Loading MobileNetV3 Large pre-trained model...")
    model = models.mobilenet_v3_large(weights=models.MobileNet_V3_Large_Weights.DEFAULT)
    model.eval()  # Switch to evaluation mode
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.error(f"Error loading model: {str(e)}")
    raise RuntimeError("Failed to load the model.") from e
```
Make sure that all required modules are imported and logging is correctly configured. The model must be loaded before any request is processed to avoid runtime errors.
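The preprocess function above comes from the local image_transforms module, which this lesson does not list. As a reference point, here is a minimal sketch of what such a module might contain, assuming the standard ImageNet preprocessing used with torchvision classification models such as MobileNetV3:

```python
# image_transforms.py (hypothetical sketch, not the lesson's actual module)
from torchvision import transforms

# Standard ImageNet preprocessing: resize, center-crop to 224x224,
# convert to a tensor, and normalize with ImageNet channel statistics.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```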
The /predict endpoint handles POST requests. It accepts a JSON payload that contains an image encoded in Base64. This endpoint decodes the image, preprocesses it, performs inference using the model, and returns the prediction in JSON format.
```python
@app.route('/predict', methods=['POST'])
def predict():
    try:
        # Extract Base64 string from the incoming JSON request
        data = request.json
        if not data or 'image' not in data:
            logger.warning("No image provided in the request.")
            return jsonify({'error': 'No image provided'}), 400

        # Decode the Base64 image string
        image_data = base64.b64decode(data['image'])
        image = Image.open(io.BytesIO(image_data)).convert('RGB')

        # Preprocess the image and add the batch dimension if required
        transformed_img = preprocess(image).unsqueeze(0)

        # Perform inference in a no_grad context to save memory
        with torch.no_grad():
            logger.info("Performing inference...")
            output = model(transformed_img)
            _, predicted = torch.max(output.data, 1)
            logger.info(f"Inference complete. Predicted class: {predicted.item()}")

        # Return the prediction as a JSON response
        response = {'prediction': predicted.item()}
        logger.info(f"Response for /predict: {response}")
        return jsonify(response)
    except Exception as e:
        logger.error(f"Error during prediction: {str(e)}")
        response = {'error': str(e)}
        logger.info(f"Response for /predict: {response}")
        return jsonify(response), 500
```
This endpoint performs the following steps:

1. Parses the request payload and verifies the presence of an "image" key.
2. Decodes the Base64-encoded image and converts it into an RGB image.
3. Applies image preprocessing before passing the tensor to the model.
4. Retrieves and returns the prediction using Flask’s jsonify method.
The /health endpoint is a simple GET endpoint used to verify that the server is running correctly. It returns a JSON response with a health status.
```python
@app.route('/health', methods=['GET'])
def health():
    """
    Health check endpoint to confirm the app is running.
    """
    response = {'status': 'healthy'}
    logger.info(f"Response for /health: {response}")
    return jsonify(response), 200
```
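One detail the snippets above leave out: for the `python app.py` command in the next step to start a server, app.py needs the usual entry-point guard at the bottom. A minimal sketch, with debug=True assumed from the "Debug mode: on" line in the sample output below:

```python
# Hypothetical entry point for app.py; debug=True is assumed from the
# "Debug mode: on" line in the sample output and is for development only.
if __name__ == '__main__':
    app.run(host='127.0.0.1', port=5000, debug=True)
```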
To start the Flask application, run the following command from your terminal:
```bash
python app.py
```
This command initializes the app, loads the model, and starts a development server, typically accessible at http://127.0.0.1:5000.

Example terminal output:
```text
root@pytorch demos/040-040-introduction-to-flask/flask_app on [] main [!?] via 🐍 v3.11.4 (venv)
→ python app.py
2025-01-15 01:36:46,774 - INFO - Loading MobileNetV3 Large pre-trained model...
2025-01-15 01:36:46,912 - INFO - Model loaded successfully.
 * Serving Flask app 'app'
 * Debug mode: on
2025-01-15 01:36:46,919 - INFO - WARNING: This is a development server. Do not use it in a production deployment.
 * Running on http://127.0.0.1:5000
2025-01-15 01:36:46,919 - INFO - Press CTRL+C to quit
```
Do not use the Flask development server in a production environment. For production deployments, consider using a WSGI server such as Gunicorn.
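The test snippets below are run from the Notebook and rely on a few names (requests, headers, base64_string, payload) that were set up earlier in the original Notebook. Here is a minimal setup sketch, where puppy.jpg is a placeholder for your test image:

```python
import base64

import requests

# Encode a local test image as a Base64 string (puppy.jpg is a placeholder)
with open("puppy.jpg", "rb") as f:
    base64_string = base64.b64encode(f.read()).decode("utf-8")

payload = {"image": base64_string}
headers = {"Content-Type": "application/json"}

# Sanity check: the /health endpoint should report a healthy status
health_response = requests.get("http://127.0.0.1:5000/health")
print("Health:", health_response.json())  # expected: {'status': 'healthy'}

# A well-formed prediction request sends the image under the "image" key
response = requests.post("http://127.0.0.1:5000/predict", json=payload, headers=headers)
print("Status Code:", response.status_code)  # expected: 200
print("Response JSON:", response.json())     # e.g. {'prediction': 207}
```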
Test error handling by sending requests without the required payload or using an incorrect key:
```python
# Test without sending any payload
error_response = requests.post("http://127.0.0.1:5000/predict", headers=headers)
print("Status Code:", error_response.status_code)
print("Response JSON:", error_response.json())

# Test with an incorrectly formatted payload
error_response = requests.post("http://127.0.0.1:5000/predict", json={"video": base64_string}, headers=headers)
print("Status Code:", error_response.status_code)
print("Response JSON:", error_response.json())
```
The first case should return a 500 error, because the empty body fails JSON decoding inside the try block and is caught by the generic exception handler, while the second returns a 400 status with a message indicating that no image was provided.
For production deployments, use a robust WSGI server like Gunicorn. Start the Gunicorn server with the following command:
```bash
gunicorn -w 2 -b 0.0.0.0:8080 app:app
```

Here, -w 2 starts two worker processes, -b 0.0.0.0:8080 binds the server to port 8080 on all interfaces, and app:app points Gunicorn at the app object inside app.py.
Test the application on port 8080:
```python
# Send a POST request using Gunicorn on port 8080
response = requests.post("http://127.0.0.1:8080/predict", json=payload, headers=headers)
print("Status Code:", response.status_code)
print("Response JSON:", response.json())
```
Terminal logs should display messages similar to:
```text
2025-01-15 01:36:46,774 - INFO - Loading MobileNetV3 Large pre-trained model...
2025-01-15 01:36:46,912 - INFO - Model loaded successfully.
...
2025-01-15 01:36:49,478 - INFO - * Debugger PIN: 808-753-342
2025-01-15 01:39:17,671 - INFO - Response for /predict: {'prediction': 207}
```
To convert the numeric prediction (e.g., 207) into a human-readable class label, use a mapping file (labels.json) available from Hugging Face under the name "Imagenet 1K Labels". After downloading the file, use the following code to interpret the prediction:
```python
import json

with open("labels.json", "r") as f:
    imagenet_classes = json.load(f)

# Retrieve the class name for the predicted class
class_label = imagenet_classes['207']
print(class_label)
```
If, for instance, the prediction corresponds to a golden retriever, the output should confirm the image class as "golden retriever", an ideal match if your input image depicts a golden retriever puppy.

This concludes our introduction to Flask and model deployment. With Flask, you can quickly set up HTTP endpoints to serve machine learning models, complete with robust error handling and logging. Happy coding!