How to use a Custom Embedding Model or LLM with Arguflow
The LLM is the easiest thing to replace: all you have to do is set the OPENAI_BASE_URL environment variable to an OpenAI-compliant API endpoint that serves text completions with a gpt-3.5-turbo model. The easiest way to do this is by self-hosting LocalAI.
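For example, assuming a LocalAI instance running on the same host on its default port of 8080 (both of these are assumptions; adjust them for your deployment), the variable would look like this:
OPENAI_BASE_URL=http://localhost:8080/v1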
This guide will demonstrate how to build a custom embedding server for the Arguflow API, allowing you to generate embeddings with your own models instead of relying on OpenAI.
Prerequisites
- A server with a GPU is required to run the custom embedding server. If you do not have access to a GPU, you can use a cloud service such as Google Colab.
- Python 3.9
- PyTorch
- HuggingFace Transformers
- Flask
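Assuming a working Python environment, the Python dependencies can be installed from PyPI:
pip install torch transformers flask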
Step 1: Create a new Python file
Create a new Python file called embedding_server.py
and import the following packages:
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModel
import torch
Step 2: Load the model
Load the model and tokenizer. For this example, we will use the bert-base-uncased
model.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
model.eval()  # inference only; disables dropout
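As an optional sanity check, you can print the model's hidden size; bert-base-uncased produces 768-dimensional token vectors, which is the dimensionality of the embeddings this server will return:
print(model.config.hidden_size)  # 768 for bert-base-uncased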
Step 3: Create the Flask app
Create the Flask app and add a route for the /embed
endpoint. This endpoint will be used to generate embeddings for the Arguflow API.
app = Flask(__name__)
@app.route("/embed", methods=["POST"])
def embed():
return jsonify({"embeddings": []})
Step 4: Generate embeddings
Now we can generate embeddings for the Arguflow API. The Arguflow API will send a POST request to the /embed
endpoint with the following JSON body:
{
"sentences": ["This is the first sentence.", "This is the second sentence."]
}
We can read the sentences field and generate an embedding for each sentence with the bert-base-uncased model loaded in Step 2. Replace the stub from Step 3 with the following implementation (Flask will raise a duplicate-endpoint error if both definitions of embed are left in the file):
@app.route("/embed", methods=["POST"])
def embed():
sentences = request.json["sentences"]
inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
outputs = model(**inputs)
embeddings = outputs.last_hidden_state.detach().numpy().tolist()
return jsonify({"embeddings": embeddings})
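Note that last_hidden_state returns one vector per token, which matches the nested output shown in Step 5. If your setup instead expects a single fixed-length vector per sentence, a common alternative (a sketch, not part of the handler above) is to mean-pool the token vectors over the attention mask:
@app.route("/embed", methods=["POST"])
def embed():
    sentences = request.json["sentences"]
    inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Zero out padding positions, then average the remaining token vectors.
    mask = inputs["attention_mask"].unsqueeze(-1).float()
    summed = (outputs.last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    embeddings = (summed / counts).tolist()
    return jsonify({"embeddings": embeddings})
This yields one 768-dimensional vector per input sentence.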
Step 5: Run the server
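Before starting the server, add an entry point to the bottom of embedding_server.py so that Flask actually serves requests (the host and port here are assumptions; adjust them as needed):
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)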
Now we can run the server and test it out. Run the following command to start the server:
python embedding_server.py
You should see output similar to the following:
* Serving Flask app 'embedding_server' (lazy loading)
* Environment: production
* Debug mode: off
* Running on http://0.0.0.0:5000/ (Press CTRL+C to quit)
Now we can test the server by sending a POST request to the /embed
endpoint. You can use the following command:
curl -X POST -H "Content-Type: application/json" -d '{"sentences": ["This is the first sentence.", "This is the second sentence."]}' http://localhost:5000/embed
You should see output like the following (values truncated):
{
  "embeddings": [
    [
      [
        -0.3405719699859619, ...
      ],
      ...
    ]
  ]
}
Step 6: Call the API from the Arguflow API
Now that we have a custom embedding server, we can point the Arguflow API at it. Open the .env
file and set the EMBEDDING_SERVER_CALL
variable to the URL of the custom embedding server:
EMBEDDING_SERVER_CALL = <Link to embedding server>
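For example, if the embedding server from this guide is running on the same machine with the entry point added in Step 5 (an assumed setup; substitute your own URL), this would be:
EMBEDDING_SERVER_CALL=http://localhost:5000/embed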