Running Stable Diffusion inference with ONNX models and ONNX Runtime

Using the diffusers pipeline directly:
import os
from diffusers import OnnxStableDiffusionPipeline, OnnxRuntimeModel
from diffusers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, DPMSolverMultistepScheduler
from transformers import CLIPTextModel, CLIPTokenizer

model_dir = "/mnt/f/deep_learning/onnx_model/stable_diffusio_v1.5/"
prompt = "a photo of an astronaut riding a horse on mars"
num_inference_steps = 20

scheduler = PNDMScheduler.from_pretrained(os.path.join(model_dir, "scheduler/scheduler_config.json"))
tokenizer = CLIPTokenizer.from_pretrained(model_dir, subfolder="tokenizer")
text_encoder = OnnxRuntimeModel(model=OnnxRuntimeModel.load_model(os.path.join(model_dir, "text_encoder/model.onnx")))
# in txt to image, vae_encoder is not necessary, only used in image to image generation
# vae_encoder = OnnxRuntimeModel(model=OnnxRuntimeModel.load_model(os.path.join(model_dir, "vae_encoder/model.onnx")))
vae_decoder = OnnxRuntimeModel(model=OnnxRuntimeModel.load_model(os.path.join(model_dir, "vae_decoder/model.onnx")))
unet = OnnxRuntimeModel(model=OnnxRuntimeModel.load_model(os.path.join(model_dir, "unet/model.onnx")))

pipe = OnnxStableDiffusionPipeline(
    vae_encoder=None,
    vae_decoder=vae_decoder,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=None,
    requires_safety_checker=False,
)
image = pipe(prompt, num_inference_steps=num_inference_steps).images[0]
image.save("generated_image.png")
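The model_dir above is expected to contain an ONNX export of Stable Diffusion v1.5 with the unet/, text_encoder/, vae_decoder/, tokenizer/ and scheduler/ subfolders. A minimal sketch for obtaining such an export, assuming the runwayml/stable-diffusion-v1-5 repository still provides an "onnx" revision (the revision name and the local path are assumptions, adjust to your setup):

import os
from diffusers import OnnxStableDiffusionPipeline

# download the pre-exported ONNX weights and save them in the folder layout used above
export_dir = "/mnt/f/deep_learning/onnx_model/stable_diffusio_v1.5/"
pipe = OnnxStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="onnx",                     # assumed pre-exported ONNX branch on the Hub
    provider="CPUExecutionProvider",
)
pipe.save_pretrained(export_dir)
print(os.listdir(export_dir))            # should show unet, text_encoder, vae_decoder, tokenizer, scheduler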
The following version, adapted from pipeline_onnx_stable_diffusion, calls the ONNX models directly; it can also serve as a reference for porting the inference to other engines:
pipe_onnx_simple.py
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import inspect
import logging as logger
from typing import Callable, List, Optional, Union

import numpy as np
import torch
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPTokenizer
from diffusers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, DPMSolverMultistepScheduler
from diffusers.configuration_utils import FrozenDict
from diffusers.utils import deprecate

from onnx_utils_simple import OnnxRuntimeModel, ORT_TO_NP_TYPE

ort_device = "cpu"  # or "gpu"
class OnnxStableDiffusionPipeline():
    # vae_encoder: OnnxRuntimeModel
    vae_decoder: OnnxRuntimeModel
    text_encoder: OnnxRuntimeModel
    tokenizer: CLIPTokenizer
    unet: OnnxRuntimeModel
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]

    def __init__(self, model_dir):
        # scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
        # stable-diffusion-v1-5 uses PNDMScheduler by default
        self.scheduler = PNDMScheduler.from_pretrained(os.path.join(model_dir, "scheduler/scheduler_config.json"))
        # stable-diffusion-2-1 uses DDIMScheduler by default
        # self.scheduler = DDIMScheduler.from_pretrained(os.path.join(model_dir, "scheduler/scheduler_config.json"))
        '''
        self.scheduler = DPMSolverMultistepScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            trained_betas=None,
            predict_epsilon=True,
            thresholding=False,
            algorithm_type="dpmsolver++",
            solver_type="midpoint",
            lower_order_final=True,
        )
        '''
        # self.scheduler = EulerAncestralDiscreteScheduler.from_config(
        #     os.path.join(model_dir, "scheduler/scheduler_config.json"))
        # self.tokenizer = BertTokenizer.from_pretrained(os.path.join(model_dir, "./tokenizer"))
        self.tokenizer = CLIPTokenizer.from_pretrained(model_dir, subfolder="tokenizer")
        self.text_encoder = OnnxRuntimeModel(os.path.join(model_dir, "text_encoder/model.onnx"), device=ort_device)
        # in txt to image, vae_encoder is not necessary, only used in image to image generation
        # self.vae_encoder = OnnxRuntimeModel(os.path.join(model_dir, "vae_encoder/model.onnx"))
        self.vae_decoder = OnnxRuntimeModel(os.path.join(model_dir, "vae_decoder/model.onnx"), device=ort_device)
        self.unet = OnnxRuntimeModel(os.path.join(model_dir, "unet/model.onnx"), device=ort_device)
        self.safety_checker = None
        self.requires_safety_checker = False
        self.feature_extractor = None
        self.progress_bar = tqdm
        if hasattr(self.scheduler.config, "steps_offset") and self.scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {self.scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {self.scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(self.scheduler.config)
            new_config["steps_offset"] = 1
            self.scheduler._internal_dict = FrozenDict(new_config)
        if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {self.scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(self.scheduler.config)
            new_config["clip_sample"] = False
            self.scheduler._internal_dict = FrozenDict(new_config)
        if self.safety_checker is None and self.requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )
        # if self.safety_checker is not None and self.feature_extractor is None:
        #     raise ValueError(
        #         "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
        #         " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        #     )
    def check_inputs(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int],
        width: Optional[int],
        callback_steps: int,
        negative_prompt: Optional[str] = None,
        prompt_embeds: Optional[np.ndarray] = None,
        negative_prompt_embeds: Optional[np.ndarray] = None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[np.random.RandomState] = None,
        latents: Optional[np.ndarray] = None,
        prompt_embeds: Optional[np.ndarray] = None,
        negative_prompt_embeds: Optional[np.ndarray] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`np.random.RandomState`, *optional*):
                A numpy random state used to make generation deterministic.
            latents (`np.ndarray`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`np.ndarray`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`np.ndarray`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                Kept for interface compatibility with the diffusers pipeline.
            return_dict (`bool`, *optional*, defaults to `True`):
                Kept for interface compatibility with the diffusers pipeline.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
        Returns:
            `np.ndarray`: the generated images as an array of shape `(batch, height, width, 3)` with values in [0, 1].
        """
        # check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )
        # define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        if generator is None:
            generator = np.random
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        prompt_embeds = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )
        # get the initial random noise unless the user supplied it
        latents_dtype = prompt_embeds.dtype
        latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
        if latents is None:
            latents = generator.randn(*latents_shape).astype(latents_dtype)
        elif latents.shape != latents_shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        latents = latents * np.float64(self.scheduler.init_noise_sigma)
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        timestep_dtype = next(
            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
        )
        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
            latent_model_input = latent_model_input.cpu().numpy()
            # predict the noise residual
            timestep = np.array([t], dtype=timestep_dtype)
            noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
            noise_pred = noise_pred[0]
            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            # compute the previous noisy sample x_t -> x_t-1
            scheduler_output = self.scheduler.step(
                torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
            )
            latents = scheduler_output.prev_sample.numpy()
            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)
        latents = 1 / 0.18215 * latents
        # image = self.vae_decoder(latent_sample=latents)[0]
        # decoding one latent at a time, since the half-precision vae decoder can give strange
        # results when the batch size is greater than 1
        image = np.concatenate(
            [self.vae_decoder(latent_sample=latents[i: i + 1])[0] for i in range(latents.shape[0])]
        )
        image = np.clip(image / 2 + 0.5, 0, 1)
        image = image.transpose((0, 2, 3, 1))
        return image
    def _encode_prompt(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: Optional[int],
        do_classifier_free_guidance: bool,
        negative_prompt: Optional[str],
        prompt_embeds: Optional[np.ndarray] = None,
        negative_prompt_embeds: Optional[np.ndarray] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.
        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            prompt_embeds (`np.ndarray`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`np.ndarray`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from the `negative_prompt` input
                argument.
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        if prompt_embeds is None:
            # get prompt text embeddings
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="np",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
            if not np.array_equal(text_input_ids, untruncated_ids):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )
            prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
        prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt] * batch_size
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt
            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="np",
            )
            negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
        if do_classifier_free_guidance:
            negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])
        return prompt_embeds

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]
        return pil_images

model_dir = "/mnt/f/deep_learning/onnx_model/stable_diffusio_v1.5/"
prompt = "a photo of an astronaut riding a horse on mars"
num_inference_steps = 20

onnx_pipe = OnnxStableDiffusionPipeline(model_dir)
image = onnx_pipe(prompt, num_inference_steps=num_inference_steps)
images = onnx_pipe.numpy_to_pil(image)
for i, image in enumerate(images):
    image.save(f"generated_image_{i}.png")
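For reproducible results with this simplified pipeline, a seeded np.random.RandomState can be passed as the generator argument, since the initial latents are drawn from it. A small usage sketch reusing the objects defined above (the seed value is just an example):

import numpy as np

# same seed -> same initial latents -> same image for a given prompt and step count
generator = np.random.RandomState(42)
image = onnx_pipe(prompt, num_inference_steps=num_inference_steps, generator=generator)
onnx_pipe.numpy_to_pil(image)[0].save("generated_image_seed42.png")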
onnx_utils_simple.py
import logging as logger

import numpy as np
import onnxruntime as ort

ORT_TO_NP_TYPE = {
    "tensor(bool)": np.bool_,
    "tensor(int8)": np.int8,
    "tensor(uint8)": np.uint8,
    "tensor(int16)": np.int16,
    "tensor(uint16)": np.uint16,
    "tensor(int32)": np.int32,
    "tensor(uint32)": np.uint32,
    "tensor(int64)": np.int64,
    "tensor(uint64)": np.uint64,
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
}


class OnnxRuntimeModel:
    def __init__(self, model_path, device="cpu"):
        self.model = None
        providers = ["CPUExecutionProvider"]
        if device == "gpu":
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        if model_path:
            self.load_model(model_path, providers)

    def __call__(self, **kwargs):
        inputs = {k: np.array(v) for k, v in kwargs.items()}
        return self.model.run(None, inputs)

    def load_model(self, path: str, providers=None, sess_options=None):
        """
        Loads an ONNX inference session with the given execution providers. The default provider is
        `CPUExecutionProvider`.
        Arguments:
            path (`str`):
                Path to the ONNX model file to load
            providers (`List[str]`, *optional*):
                Onnxruntime execution providers to use for the session, defaults to `["CPUExecutionProvider"]`
        """
        if providers is None:
            logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
            providers = ["CPUExecutionProvider"]  # "CUDAExecutionProvider",
        self.model = ort.InferenceSession(path, providers=providers, sess_options=sess_options)
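The wrapper can also be used on its own, for example to run just the text encoder and inspect its output. A small sketch, assuming the same model_dir layout as above:

import os
import numpy as np
from transformers import CLIPTokenizer
from onnx_utils_simple import OnnxRuntimeModel

model_dir = "/mnt/f/deep_learning/onnx_model/stable_diffusio_v1.5/"
tokenizer = CLIPTokenizer.from_pretrained(model_dir, subfolder="tokenizer")
text_encoder = OnnxRuntimeModel(os.path.join(model_dir, "text_encoder/model.onnx"), device="cpu")

ids = tokenizer(
    "a photo of an astronaut riding a horse on mars",
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="np",
).input_ids
outputs = text_encoder(input_ids=ids.astype(np.int32))
print(outputs[0].shape)  # (1, 77, 768) for stable diffusion v1.5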
Shape information when generating one 512x512 image:
text_encoder
    input_ids (1, 77)
    results shape: (1, 77, 768)
    results shape: (1, 768)
unet
    sample (2, 4, 64, 64)
    timestep (1,)
    encoder_hidden_states (2, 77, 768)
    results shape: (2, 4, 64, 64)
vae_decoder
    latent_sample (1, 4, 64, 64)
    results shape: (1, 3, 512, 512)
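These shapes were observed at runtime; the declared graph inputs and outputs can also be inspected directly with onnxruntime. A small sketch (note that exported models often report symbolic dimension names such as "batch" instead of the concrete sizes listed above):

import os
import onnxruntime as ort

model_dir = "/mnt/f/deep_learning/onnx_model/stable_diffusio_v1.5/"
for sub in ["text_encoder", "unet", "vae_decoder"]:
    sess = ort.InferenceSession(os.path.join(model_dir, sub, "model.onnx"), providers=["CPUExecutionProvider"])
    print(sub)
    for inp in sess.get_inputs():
        print("  input :", inp.name, inp.shape, inp.type)
    for out in sess.get_outputs():
        print("  output:", out.name, out.shape, out.type)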


Source: https://blog.csdn.net/u013701860/article/details/129772995