[GLM-Image] New Models Support #12921
base: main
Conversation
sayakpaul
left a comment
Looking quite good!
I think all the precomputations are in place, and the use of caching also reads quite simply.
refactor attention processor to use dispatching function
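For context, here is a minimal generic sketch of the dispatching idea (this is not the actual diffusers helper used in the PR; the function name and backend handling below are assumptions): a single function selects the attention backend so individual processors don't hard-code one.

```python
import torch
import torch.nn.functional as F

def dispatch_attention(query, key, value, backend: str = "sdpa"):
    # Route the attention computation to the requested backend;
    # only PyTorch SDPA is wired up in this sketch.
    if backend == "sdpa":
        return F.scaled_dot_product_attention(query, key, value)
    raise NotImplementedError(f"Unknown attention backend: {backend}")

# Toy tensors shaped (batch, heads, seq_len, head_dim), as SDPA expects.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
out = dispatch_attention(q, k, v)  # (1, 8, 16, 64)
```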
yiyixuxu
left a comment
Thanks, looking great, and super excited about this model!
I left some comments; mostly, I'm a bit confused about the correct logic for setting height/width.
```python
prior_token_id, prior_token_image_ids, ar_height, ar_width = self.generate_prior_tokens(
    prompt=prompt[0] if isinstance(prompt, list) else prompt,
    image=image,
    height=height,
    width=width,
)
```
A few things here:

- `generate_prior_tokens` will error out if `height = None` and `width = None`
- `ar_height`/`ar_width` is pretty straightforward to calculate; let's calculate them separately for clarity
- we can update `generate_prior_tokens` to only return the two token tensors; this way it is easier for users to skip this stage by reusing pre-computed tokens

Here is just a suggestion; I'm not completely sure the logic to assign the default height/width is correct:
Suggested change (replacing the current `generate_prior_tokens` call):

```python
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
height = (height // 32) * 32
width = (width // 32) * 32

prior_token_id, prior_token_image_ids = self.generate_prior_tokens(
    prompt=prompt[0] if isinstance(prompt, list) else prompt,
    image=image,
    height=height,
    width=width,
)
```
Let me add a check to ensure that height and width cannot be None. This is a strict requirement, as these two parameters must be present for the AR model to correctly output tokens.
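A minimal sketch of the kind of check described (the helper name, placement, and error message are assumptions, not the final implementation):

```python
def check_resolution(height, width):
    # Hypothetical helper: height/width are required because the AR prior
    # derives its token grid from them.
    if height is None or width is None:
        raise ValueError(
            "`height` and `width` must both be provided; the AR model needs them "
            "to generate the correct number of prior tokens."
        )
    # Snap to the 32-pixel granularity, mirroring the suggestion above.
    return (height // 32) * 32, (width // 32) * 32

height, width = check_resolution(1024, 1024)  # -> (1024, 1024)
```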
```python
    f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)

if prompt is not None and prompt_embeds is not None:
```
So, the current code structure won't work with `prompt=None` when `prompt_embeds` is passed; we still need the prompt to generate tokens using the AR model.

I think we'd need to accept both `prior_token_id` and `prompt_embeds` as inputs if `prompt` is None, so something like:

```python
if prompt is None:
    if prior_token_id is None or prompt_embeds is None:
        raise ValueError(
            "When `prompt` is not provided, both `prior_token_id` and `prompt_embeds` must be passed."
        )
```

You also need to add `prior_token_id` to the pipeline inputs.
`prior_token_id` must be generated by the AR model, so `prompt` must not be None.
Let me change the logic for this.
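A hedged sketch of the adjusted validation being described (the helper name and error wording are assumptions): since `prior_token_id` has to come from the AR model, `prompt` stays required even when `prompt_embeds` is supplied.

```python
def check_prompt(prompt, prompt_embeds=None):
    # Hypothetical check: the AR prior can only run from a text prompt, so a
    # prompt is mandatory even if text-encoder embeddings were pre-computed.
    if prompt is None:
        raise ValueError(
            "`prompt` must be provided so the AR model can generate `prior_token_id`, "
            "even when `prompt_embeds` is passed."
        )

check_prompt("a photo of a cat")  # passes
```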
```python
    self.k_cache = k
    self.v_cache = v
else:
    self.k_cache = torch.cat([self.k_cache, k], dim=2)
```
Referring to L253, should dim be equal to 1 here?
```python
self.k_cache = torch.cat([self.k_cache, k], dim=2)
self.v_cache = torch.cat([self.v_cache, v], dim=2)

def get(self):
```
Not sure if it should be 1 or 2, but they should be the same.
It's probably better to move the logic together like this, so mistakes like https://github.com/huggingface/diffusers/pull/12921/files#r2678634789 are less likely to happen:

```python
def get(self, k: torch.Tensor, v: torch.Tensor):
    k_cache = torch.cat([self.k_cache, k], dim=2)
    v_cache = torch.cat([self.v_cache, v], dim=2)
```
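To make the "keep the concat dim in one place" point concrete, here is a self-contained sketch of such a cache (the class and method names are assumptions; the sequence dimension is stored once so the write and read paths cannot disagree):

```python
import torch

class KVCache:
    def __init__(self, seq_dim: int = 2):
        # The sequence dimension is stored once, so every concat uses the same axis.
        self.seq_dim = seq_dim
        self.k_cache = None
        self.v_cache = None

    def update(self, k: torch.Tensor, v: torch.Tensor):
        # Append the new keys/values to the cache and return the full history.
        if self.k_cache is None:
            self.k_cache, self.v_cache = k, v
        else:
            self.k_cache = torch.cat([self.k_cache, k], dim=self.seq_dim)
            self.v_cache = torch.cat([self.v_cache, v], dim=self.seq_dim)
        return self.k_cache, self.v_cache

# Toy usage with (batch, heads, seq_len, head_dim) tensors and seq_dim=2:
cache = KVCache(seq_dim=2)
k0, v0 = torch.randn(1, 8, 4, 64), torch.randn(1, 8, 4, 64)
k1, v1 = torch.randn(1, 8, 1, 64), torch.randn(1, 8, 1, 64)
cache.update(k0, v0)
k_all, v_all = cache.update(k1, v1)  # shapes: (1, 8, 5, 64)
```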
```python
k_cache, v_cache = kv_cache.get()
key = torch.cat([k_cache, key], dim=1) if k_cache is not None else key
value = torch.cat([v_cache, value], dim=1) if v_cache is not None else value
```
Suggested change:

```python
if kv_cache is not None:
    key, value = kv_cache.get(key, value)
```
```python
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
```
Suggested change:

```python
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prior_token_ids: Optional[torch.Tensor] = None,
prior_image_token_ids: Optional[torch.Tensor] = None,
```
We should allow users to pre-compute the tokens, since that is the most compute-expensive part.
We should also allow passing pre-computed `negative_prompt_embeds`, because it is fixed.
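A hedged usage sketch of the workflow this enables (the argument names follow the suggestion above and are assumptions): the expensive AR stage runs once, and its outputs plus cached negative embeddings are reused in later calls.

```python
# Assume `pipe` is the GLM-Image pipeline from this PR, already loaded, and
# `cached_negative_prompt_embeds` was computed earlier.

# 1) Pre-compute the prior tokens once (the most compute-expensive step).
prior_token_ids, prior_image_token_ids = pipe.generate_prior_tokens(
    prompt="a photo of a cat",
    image=None,
    height=1024,
    width=1024,
)

# 2) Reuse them (plus the fixed negative prompt embeddings) across calls.
image = pipe(
    prompt="a photo of a cat",
    prior_token_ids=prior_token_ids,
    prior_image_token_ids=prior_image_token_ids,
    negative_prompt_embeds=cached_negative_prompt_embeds,
    height=1024,
    width=1024,
).images[0]
```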
```python
device = self._execution_device

prior_token_id, prior_token_image_ids = self.generate_prior_tokens(
```
Suggested change:

```python
if prior_token_ids is None:
    prior_token_id, prior_token_image_ids = self.generate_prior_tokens(...)
```
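Expanded into a hedged sketch (the elided arguments are taken from the call shown earlier in this review; the normalized variable naming is an assumption):

```python
# Only run the expensive AR stage when tokens were not supplied by the caller.
if prior_token_ids is None:
    prior_token_ids, prior_image_token_ids = self.generate_prior_tokens(
        prompt=prompt[0] if isinstance(prompt, list) else prompt,
        image=image,
        height=height,
        width=width,
    )
```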
@yiyixuxu @sayakpaul Please check this with the model.