Fine-Tuning Stable Diffusion XL Using AptAI APIs

Using AptAI APIs, you can fine-tune Stable Diffusion XL (SDXL), one of Stability AI's most powerful open-source models. After creating an AptAI API project for fine-tuning SDXL and gaining access to its server as described here, you can follow the steps on this page to fine-tune an SDXL model using Python. Alternatively, after logging in, you can follow this readme while using the /docs page for the same purpose. Make sure to replace the values inside '<>' with values appropriate for your setup.
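All the examples below send the same API key header to the same private server URL. As an optional convenience (not required by the API), you can wrap these in a requests.Session once so the header is attached to every request automatically; a minimal sketch:

Python
import requests
from urllib.parse import urljoin

# A Session attaches the API key header to every request it makes
session = requests.Session()
session.headers.update({"x-api-key": "<YOUR-API-KEY>"})
base_url = "<YOUR-PRIVATE-SERVER-URL>"

# e.g. session.post(urljoin(base_url, '/api/v1/tasks/all')) instead of
# passing headers=... to every requests.post call

The examples on this page use plain requests calls with an explicit header dict, which works just as well.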

Step 1: Upload your training images

To fine-tune SDXL you will need at least a few images (three to four are often enough). Here is example code for uploading images. You can find directions for locating your API key and private server URL here. Note that you can use the same endpoint to upload images for img2img tasks as well.

Python
import io
import json
import time
import requests
from PIL import Image
from urllib.parse import urljoin

header={"x-api-key" : "<YOUR-API-KEY>"}
base_url = "<YOUR-PRIVATE-SERVER-URL>"


files = [
     ('file', open('<PATH-TO-YOUR-TRAINING-IMAGE1>','rb')), 
     ('file', open('<PATH-TO-YOUR-TRAINING-IMAGE2>','rb')), 
]

url = urljoin(base_url, '/api/v1/upload_multiple_files')

r = requests.post(url=url, files=files, headers=header)

if r.status_code != 200:
    print("Error uploading file...")
    print(r.json())

image_uuids = r.json()['file_uuids']
print("image_uuids:", image_uuids)

Step 2: Start a fine-tuning task

Once the files are uploaded to your server, you can start a fine-tuning task that uses them. Fine-tuning can take anywhere from about five minutes to more than an hour, depending on the number of training images, the configuration of your private server (such as its GPU), and the parameters you set, especially the number of steps. Make sure to include a unique keyword that refers to the object you are training, for instance 'my_obj1'.

Pro tip: This API endpoint supports prompt weighting. You can use “+” and “-” to emphasize or de-emphasize words (or groups of words in parentheses) in a prompt.
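For example, the following illustrative prompt boosts the subject and its clothing while downweighting the background (the same syntax appears in the headshot example in Step 4):

Python
# '+' after a word or a parenthesized group increases its weight; '-' decreases it
instance_prompt = "a photo of (sbs1 man)++ wearing a (black suit)+, plain background-"

With your images uploaded, you can now submit the fine-tuning request: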

Python
values = {
    "lora_name": "<AN-ARBITRARY-NAME-FOR-YOUR-LORA-MODEL>",
    "base_model_name": "stabilityai/stable-diffusion-xl-base-1.0",
    "pretrained_vae_model_name_or_path": "madebyollin/sdxl-vae-fp16-fix",
    "input_images_uuids": image_uuids,
    "instance_prompt": "<PROMPT-DESCRIBING-THE-INPUT>",
    "keywords": "<YOUR-UNIQUE-KEYWORD>",
    "resolution": 1024,
    "batch_size": 1,
    "gradient_accumulation_steps": 4,
    "learning_rate": 0.0001,
    "lr_warmup_steps": 0,
    "lr_scheduler": "linear",
    "max_train_steps": 500,
    "validation_epochs": 100,
    "checkpointing_steps": 250,
    "is_experimental": "true",
    "train_text_encoder": "true"
}
url = urljoin(base_url, '/api/v1/train')
r = requests.post(url, data=json.dumps(values), headers=header)

# A successful request is acknowledged with 202 (Accepted)
if r.status_code != 202:
    print("Error in train...")
    print(r.json())
    raise SystemExit(1)

time.sleep(10)

# The task UID is needed to track progress and fetch the result
task_uid = r.json()["task_uid"]

For instance, if you want your model to learn a new face, you can use something like the following:

Python
...
            "instance_prompt": "a photo of sbs1 man.",
            "keywords": "sbs1",
...
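As a quick sanity check on the parameters above: with "batch_size": 1 and "gradient_accumulation_steps": 4, each optimizer update effectively averages gradients over 4 images, and "max_train_steps": 500 caps the total number of training steps. You can verify the effective batch size locally before submitting:

Python
# Images contributing to each optimizer update
effective_batch = values["batch_size"] * values["gradient_accumulation_steps"]
print("effective batch size:", effective_batch)  # 4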

Step 3: Check the status of the fine-tuning task

Once the task starts, you can track its progress. To do so, you need the ID of the task you just created, which is returned in the response to the previous request.

Python
while True:
    url = urljoin(base_url, '/api/v1/tasks/progress')
    values = {"task_uid": task_uid}
    try:
        r = requests.post(url, data=json.dumps(values), headers=header, timeout=5)
        print(r.json())
        if r.status_code == 200 and r.json()["progress"] == 100:
            break
    except requests.exceptions.RequestException:
        pass  # transient network error; poll again
    time.sleep(5)

time.sleep(10)  # Waiting for the LoRA model to get prepared
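If you track several tasks, it can be convenient to fold the loop above into a reusable helper with an overall timeout, so a stalled task cannot block your script forever. A sketch (the timeout policy is our own, not part of the API):

Python
def wait_for_task(task_uid, poll_interval=5, max_wait=3600):
    # Poll the progress endpoint until the task reports 100% or we time out
    url = urljoin(base_url, '/api/v1/tasks/progress')
    deadline = time.time() + max_wait
    while time.time() < deadline:
        try:
            r = requests.post(url, data=json.dumps({"task_uid": task_uid}),
                              headers=header, timeout=5)
            if r.status_code == 200 and r.json().get("progress") == 100:
                return True
        except requests.exceptions.RequestException:
            pass  # transient network error; keep polling
        time.sleep(poll_interval)
    return False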

Once the task is complete, you should be able to retrieve the information about the generated LoRA model.

Python
url = urljoin(base_url, '/api/v1/tasks/result')
values = {"task_uid": task_uid}
r = requests.post(url, data=json.dumps(values), headers=header)
time.sleep(2)
print(r.json())

lora_uuid = r.json()["uuid"]  # UUID of the newly trained LoRA model
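Since this UUID is what every later generation call refers to, it is worth persisting it so you do not have to retrain if your script restarts (the file name here is arbitrary):

Python
# Save the LoRA UUID for reuse by later scripts
with open("lora_models.json", "w") as f:
    json.dump({"lora_uuid": lora_uuid}, f)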

If you decide to interrupt (i.e., stop) the training process, you can use the following code:

Python
url = urljoin(base_url, '/api/v1/tasks/interrupt')
values = {"task_uid": task_uid}
r = requests.post(url, data=json.dumps(values), headers=header)
time.sleep(2)
print(r.json())

To see the statuses and results of all your fine-tuning tasks, you can use this code:

Python
url = urljoin(base_url, '/api/v1/tasks/all')
r = requests.post(url, headers=header)
print(r.json())

Step 4: Generate images (txt2img)

Once you have the UUID of your LoRA model, you can use it for txt2img and img2img tasks. Here is a txt2img example.

Python
from PIL import Image
values = {
    "lora_uuid": lora_uuid,
    "base_model_name": "stabilityai/stable-diffusion-xl-base-1.0",
    "vae_model_name": "madebyollin/sdxl-vae-fp16-fix",
    "prompt": "<YOUR-PROMPT>",
    "negative_prompt": "<YOUR-NEGATIVE-PROMPT>",
    "height": 1024,
    "width": 1024,
    "num_inference_steps": 15,
    "guidance_scale": 7.5,
    "num_images_per_prompt": 1,
    "eta": 0,
    "seed": 0
}
url = urljoin(base_url, '/api/v1/txt2img')
r = requests.post(url, data=json.dumps(values), headers=header, stream=True)

if r.status_code != 200:
    print("Error in txt2img...")
    print(r.json())
    raise SystemExit(1)

# The response body is the generated image itself
img = Image.open(io.BytesIO(r.content))
img.show()
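Because the request takes an explicit "seed", one easy way to get several candidate images to choose from is to repeat the call with different seeds, reusing the values and url defined above (the output file names are arbitrary):

Python
for seed in range(4):
    values["seed"] = seed
    r = requests.post(url, data=json.dumps(values), headers=header)
    if r.status_code == 200:
        Image.open(io.BytesIO(r.content)).save(f"output_seed{seed}.png")
    else:
        print(f"txt2img failed for seed {seed}:", r.text)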

The prompt and the negative prompt play critical roles in the quality and consistency of the images you create. Here are examples for generating professional headshots:

Python
  "prompt": "8k intricate, highly detailed, digital photography, best quality+, perfect++ eyes, masterpiece, A professional headshot of 35 year-old (sxs man)+++ dark++ brown eyes smooth skin in a (black suit)+ necktie, cinematic lighting modern",
  "negative_prompt": "(deformed iris, deformed pupils)+, text, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, (extra fingers)+, (mutated hands)+, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, (fused fingers)+, (too many fingers)+, long neck, camera",

Optional Step: Upload your LoRA to the cloud storage

The LoRA model you just trained is only accessible from the private server you are currently using. If you want to use this LoRA model with your other AptAI private servers, you can optionally upload it to our cloud storage (this incurs a small monthly cost).

Python
values = {
    "lora_uuid": lora_uuid,
}
url = urljoin(base_url, '/api/v1/upload_lora_model')
r = requests.post(url, data=json.dumps(values), headers=header)
print(r.json())