1313# limitations under the License.
1414
1515import inspect
16- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
16+ from typing import Any , Callable
1717
1818import numpy as np
1919import PIL
@@ -83,7 +83,7 @@ def retrieve_timesteps(
8383 scheduler ,
8484 num_inference_steps : int | None = None ,
8585 device : str | torch .device | None = None ,
86- timesteps : Optional [ List [ int ]] = None ,
86+ timesteps : list [ int ] | None = None ,
8787 sigmas : list [float ] | None = None ,
8888 ** kwargs ,
8989):
@@ -99,15 +99,15 @@ def retrieve_timesteps(
9999 must be `None`.
100100 device (`str` or `torch.device`, *optional*):
101101 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
102- timesteps (`List [int]`, *optional*):
102+ timesteps (`list [int]`, *optional*):
103103 Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
104104 `num_inference_steps` and `sigmas` must be `None`.
105- sigmas (`List [float]`, *optional*):
105+ sigmas (`list [float]`, *optional*):
106106 Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
107107 `num_inference_steps` and `timesteps` must be `None`.
108108
109109 Returns:
110- `Tuple [torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
110+ `tuple [torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
111111 second element is the number of inference steps.
112112 """
113113 if timesteps is not None and sigmas is not None :
@@ -208,11 +208,11 @@ def __init__(
208208 def _get_qwen3_prompt_embeds (
209209 text_encoder : Qwen3ForCausalLM ,
210210 tokenizer : Qwen2TokenizerFast ,
211- prompt : Union [ str , List [str ] ],
212- dtype : Optional [ torch .dtype ] = None ,
213- device : Optional [ torch .device ] = None ,
211+ prompt : str | list [str ],
212+ dtype : torch .dtype | None = None ,
213+ device : torch .device | None = None ,
214214 max_sequence_length : int = 512 ,
215- hidden_states_layers : List [int ] = (9 , 18 , 27 ),
215+ hidden_states_layers : list [int ] = (9 , 18 , 27 ),
216216 ):
217217 dtype = text_encoder .dtype if dtype is None else dtype
218218 device = text_encoder .device if device is None else device
@@ -317,7 +317,7 @@ def _prepare_latent_ids(
317317 @staticmethod
318318 # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids
319319 def _prepare_image_ids (
320- image_latents : List [torch .Tensor ], # [(1, C, H, W), (1, C, H, W), ...]
320+ image_latents : list [torch .Tensor ], # [(1, C, H, W), (1, C, H, W), ...]
321321 scale : int = 10 ,
322322 ):
323323 r"""
@@ -327,7 +327,7 @@ def _prepare_image_ids(
327327 dimensions.
328328
329329 Args:
330- image_latents (List [torch.Tensor]):
330+ image_latents (list [torch.Tensor]):
331331 A list of image latent feature tensors, typically of shape (C, H, W).
332332 scale (int, optional):
333333 A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
@@ -424,12 +424,12 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch
424424
425425 def encode_prompt (
426426 self ,
427- prompt : Union [ str , List [str ] ],
428- device : Optional [ torch .device ] = None ,
427+ prompt : str | list [str ],
428+ device : torch .device | None = None ,
429429 num_images_per_prompt : int = 1 ,
430430 prompt_embeds : torch .Tensor | None = None ,
431431 max_sequence_length : int = 512 ,
432- text_encoder_out_layers : Tuple [int ] = (9 , 18 , 27 ),
432+ text_encoder_out_layers : tuple [int ] = (9 , 18 , 27 ),
433433 ):
434434 device = device or self ._execution_device
435435
@@ -507,7 +507,7 @@ def prepare_latents(
507507 # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents
508508 def prepare_image_latents (
509509 self ,
510- images : List [torch .Tensor ],
510+ images : list [torch .Tensor ],
511511 batch_size ,
512512 generator : torch .Generator ,
513513 device ,
@@ -608,25 +608,25 @@ def interrupt(self):
608608 @replace_example_docstring (EXAMPLE_DOC_STRING )
609609 def __call__ (
610610 self ,
611- image : Optional [ Union [ List [ PIL .Image .Image ], PIL .Image .Image ]] = None ,
612- prompt : Union [ str , List [str ] ] = None ,
611+ image : list [ PIL .Image .Image ] | PIL .Image .Image | None = None ,
612+ prompt : str | list [str ] = None ,
613613 height : int | None = None ,
614614 width : int | None = None ,
615615 num_inference_steps : int = 50 ,
616616 sigmas : list [float ] | None = None ,
617- guidance_scale : Optional [ float ] = 4.0 ,
617+ guidance_scale : float = 4.0 ,
618618 num_images_per_prompt : int = 1 ,
619619 generator : torch .Generator | list [torch .Generator ] | None = None ,
620620 latents : torch .Tensor | None = None ,
621621 prompt_embeds : torch .Tensor | None = None ,
622- negative_prompt_embeds : Optional [ Union [ str , List [str ]]] = None ,
622+ negative_prompt_embeds : str | list [str ] | None = None ,
623623 output_type : str = "pil" ,
624624 return_dict : bool = True ,
625- attention_kwargs : Optional [ Dict [ str , Any ]] = None ,
626- callback_on_step_end : Optional [ Callable [[int , int , Dict ], None ]] = None ,
627- callback_on_step_end_tensor_inputs : List [str ] = ["latents" ],
625+ attention_kwargs : dict [ str , Any ] | None = None ,
626+ callback_on_step_end : Callable [[int , int , dict ], None ] | None = None ,
627+ callback_on_step_end_tensor_inputs : list [str ] = ["latents" ],
628628 max_sequence_length : int = 512 ,
629- text_encoder_out_layers : Tuple [int ] = (9 , 18 , 27 ),
629+ text_encoder_out_layers : tuple [int ] = (9 , 18 , 27 ),
630630 ):
631631 r"""
632632 Function invoked when calling the pipeline for generation.
@@ -693,7 +693,7 @@ def __call__(
693693 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
694694 `._callback_tensor_inputs` attribute of your pipeline class.
695695 max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
696- text_encoder_out_layers (`Tuple [int]`):
696+ text_encoder_out_layers (`tuple [int]`):
697697 Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
698698
699699 Examples:
0 commit comments