Stable Diffusion for Text-to-Image Generation

Introduction

Stable Diffusion is an open-source text-to-image generation model. While there exist multiple open-source implementations that allow you to easily create images from textual prompts, KerasCV offers a few distinct advantages. These include XLA compilation and mixed precision support, which together achieve state-of-the-art generation speed.

How does Stable Diffusion work?

Its architecture resembles an encoder-decoder design. The Stable Diffusion architecture consists of three parts:

  1. Text encoder, which turns your prompt into a latent vector.
  2. Diffusion model, which repeatedly denoises a 64×64 latent image patch.
  3. Decoder, which turns the final 64×64 latent patch into a higher-resolution image.

First, your text prompt is projected into a latent vector space by the text encoder, which is simply a pretrained, frozen language model. That prompt vector is then concatenated to a randomly generated noise patch, which the diffusion model repeatedly denoises over a series of steps. Finally, the 64×64 latent image is sent through the decoder to render it in high resolution.
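To make the three stages concrete, here is a minimal, illustrative sketch of that pipeline. The helper functions (encode_text, denoise, decode) and the placeholder shapes are hypothetical stand-ins for the text encoder, diffusion model, and decoder; they are not the actual KerasCV components.

import numpy as np

def encode_text(prompt):
    # Hypothetical stand-in for the frozen text encoder:
    # maps the prompt to a context embedding.
    return np.zeros((77, 768), dtype=np.float32)

def denoise(latent, context, step):
    # Hypothetical stand-in for the diffusion model:
    # removes a little noise, conditioned on the prompt embedding.
    return latent * 0.98

def decode(latent):
    # Hypothetical stand-in for the decoder:
    # upsamples the 64x64 latent into a 512x512 RGB image.
    return np.zeros((512, 512, 3), dtype=np.uint8)

context = encode_text("photograph of an astronaut riding a horse")
latent = np.random.normal(size=(64, 64, 4)).astype("float32")  # random noise patch

for step in range(50):            # repeated denoising steps
    latent = denoise(latent, context, step)

image = decode(latent)            # final high-resolution render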

Figure: The Stable Diffusion architecture (source: High-performance image generation using Stable Diffusion in KerasCV)

Prerequisites

Python >= 3.7
pip install matplotlib
pip install "tensorflow>=2.9.1"
pip install keras-cv

Code

First, import all the necessary libraries:

import keras_cv
from tensorflow import keras
import matplotlib.pyplot as plt

Then we download the pretrained model, passing the desired image size:

model = keras_cv.models.StableDiffusion(img_width=512, img_height=512)

After the model has downloaded successfully, we generate images by passing a text prompt and a batch size to text_to_image, then display the results using matplotlib. The number of images generated matches the batch size you specify. Execution may take a few minutes.

images = model.text_to_image("photograph of an astronaut riding a horse", batch_size=3)

def plot_images(images):
    # Display all generated images side by side.
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        plt.subplot(1, len(images), i + 1)
        plt.imshow(images[i])
        plt.axis("off")


plot_images(images)
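If you want to keep the results, you can also write each generated array to disk, for example with keras.utils.save_img (a standard Keras utility; the file names below are just an example, not part of the original tutorial):

# Save each generated image as a PNG file.
for i, img in enumerate(images):
    keras.utils.save_img(f"generated_{i}.png", img)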

Optimize model performance

Mixed precision consists of performing computation in float16 precision while storing weights in float32 format. To enable mixed precision computation in Keras, add the code below before loading the model.

keras.mixed_precision.set_global_policy("mixed_float16")
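A rough way to see the effect is to time generation with the policy enabled. The sketch below uses Python's time module; note that the policy must be set before the model is constructed so its layers are built with float16 compute, and the very first call also includes model download and warm-up, so time a second call for a fairer comparison.

import time

keras.mixed_precision.set_global_policy("mixed_float16")
model = keras_cv.models.StableDiffusion(img_width=512, img_height=512)

start = time.time()
images = model.text_to_image("photograph of an astronaut riding a horse", batch_size=3)
print(f"With mixed precision: {time.time() - start:.2f} seconds")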

TensorFlow comes with the XLA: Accelerated Linear Algebra compiler built-in. keras_cv.models.StableDiffusion supports a jit_compile argument out of the box. Setting this argument to True enables XLA compilation, resulting in a significant speed-up.

model = keras_cv.models.StableDiffusion(jit_compile=True)
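The two optimizations can be combined. A sketch putting it all together, with the mixed precision policy set first and the model constructed with XLA compilation, using the same prompt and plotting helper as above:

keras.mixed_precision.set_global_policy("mixed_float16")
model = keras_cv.models.StableDiffusion(
    img_width=512, img_height=512, jit_compile=True
)
images = model.text_to_image("photograph of an astronaut riding a horse", batch_size=3)
plot_images(images)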

For more information

Data Science – pytechie.com

High-performance image generation using Stable Diffusion in KerasCV
