o
    i(l                     @  s   d dl mZ d dlmZmZ d dlmZmZ d dlZ	d dl
Z
d dlZd dlZ	er.d dlmZ dd Zdd	 Zd
d Zdd Zdd ZG dd dZG dd dejZG dd dZG dd dejZG dd deZdd ZdS )    )annotations)TYPE_CHECKINGUnion)ioComfyExtensionN)UUIDc                 C  sR   t | tr| d ddd|f | d ddd|f fS | ddd|f dfS )zAExtract tensor from data, handling both single tensors and lists.r   N   )
isinstancelist)dataoutput_channels r   :/mnt/c/Users/fbmor/ComfyUI/comfy_extras/nodes_easycache.py_extract_tensor   s   
0r   c                 O  sp  |d }t |ts|d}|s|d }|d }t|d |j\}}|d }|d }|d ur:||r:| |i |S ||}	|}
d }||}|r	|| |	|}|j
r|r|jritd|j d	|  |||}|d ur|j||d
d}||gS |S |jr|d |_||}	d|_|	r	| r|j||dd|j    }| r	| r	|j| |j }| j|7  _|j|jk r|r|jrtd|j d|j  d
|_
|||}|d ur|j||d
d}||gS |S |jrtd|j d|j  d|_| |i |}t||j\}}|	ry| ry|j||dd|j    }|jrB||j }|j|   | rc|j| |j }|j!|   |jrctd|  |d urm|| |_|jrytd|  |"||
| |d ur|j"|||d
d |	r||
||_||||_|   |_|jrtd|jj#  |S )Ntransformer_options	easycacher   sigmasuuidsz6EasyCache [verbose] - was marked to skip this step by z. Present uuids: T)is_audioFclonez=EasyCache [verbose] - skipping step; cumulative_change_rate: , reuse_threshold: zAEasyCache [verbose] - NOT skipping step; cumulative_change_rate:         z1EasyCache [verbose] - approx_output_change_rate: z*EasyCache [verbose] - output_change_rate: z)EasyCache [verbose] - x_prev_subsampled: )$r	   dictgetr   r   is_past_end_timestephas_first_cond_uuidshould_do_easycachecheck_metadatacan_apply_cache_diffskip_current_stepverboselogginginfofirst_cond_uuidapply_cache_diffinitial_stephas_x_prev_subsampled	subsamplex_prev_subsampledflattenabsmeanhas_output_prev_norm has_relative_transformation_raterelative_transformation_rateoutput_prev_normcumulative_change_ratereuse_thresholdoutput_prev_subsampledoutput_change_ratesappenditemapprox_output_change_ratesupdate_cache_diffshape)executorargskwargsr   r   xaxr   r   r   next_x_previnput_changedo_easycacher!   resultresult_audioapprox_output_change_ratefull_outputoutputaudio_outputoutput_changeoutput_change_rater   r   r   easycache_forward_wrapper   s   








""




rL   c                 O  s@  |d }|d }|d d }| |r| |i |S |d d d d |jf }|}d }||}	|	r|| | r| rO|j|dd|j   	 }|
 r| r|j| |j }
| j|
7  _|j|jk r|jr|td|j d	|j  d
|_||S |jrtd|j d	|j  d|_| |i |}|
 r|j|dd|j   	 }|jr||j }|j|  | r|j| |j }
|j|
  |jrtd|
  |d ur|| |_|jrtd|  ||| |||_|||_|  	 |_|jrtd|jj  |S )Nr      r   r   r   Fr   z=LazyCache [verbose] - skipping step; cumulative_change_rate: r   TzALazyCache [verbose] - NOT skipping step; cumulative_change_rate: r   z1LazyCache [verbose] - approx_output_change_rate: z*LazyCache [verbose] - output_change_rate: z)LazyCache [verbose] - x_prev_subsampled: )r   r   r   r    r)   r*   r+   r,   r-   r.   r/   r0   r1   r2   r3   r4   r#   r$   r%   r"   r'   r5   r6   r7   r8   r9   r:   r;   )r<   r=   r>   timestepmodel_optionsr   r?   rA   rB   rC   rF   rH   rJ   rK   r   r   r   lazycache_predict_noise_wrapperh   s^   


 
 

rP   c                 O  s(   |d }|d d }d|_ | |i |S )Nr   r   r   F)r"   )r<   r=   r>   rO   r   r   r   r   !easycache_calc_cond_batch_wrapper   s   rQ   c           
      O  s,  z| j }|j}tj||_|jd d  |jjj|jd d< |jd d }t	
|j d|j d|j d|j  | |i |W |jd d }|j}|j}|jrtt	
|j dt| d|  t	
|j dt| d|  t|d	 d
 }z	|||j  }	W n ty   d}	Y nw t	
|j d|j d| d|	dd |  ||_S |jd d }|j}|j}|jrt	
|j dt| d|  t	
|j dt| d|  t|d	 d
 }z	|||j  }	W n ty   d}	Y nw t	
|j d|j d| d|	dd |  ||_w )z
    This OUTER_SAMPLE wrapper makes sure easycache is prepped for current run, and all memory usage is cleared at the end.
    r   r   z enabled - threshold: z, start_percent: z, end_percent: z! [verbose] - output_change_rates z: z( [verbose] - approx_output_change_rates    r         ?z - skipped /z steps (z.2fzx speedup).)	class_objrO   comfymodel_patchercreate_model_options_cloner   prepare_timestepsmodelmodel_samplingr$   r%   namer4   start_percentend_percentr6   r9   r#   lentotal_steps_skippedZeroDivisionErrorreset)
r<   r=   r>   guiderorig_model_optionsr   r6   r9   total_stepsspeedupr   r   r   easycache_sample_wrapper   sL   ((  (  (rg   c                   @  s   e Zd Zd7d8ddZd9ddZd9ddZd:ddZd:ddZd:ddZd:ddZ	dd Z
d;d<d%d&Zd=d'd(Zd>d?d*d+Zd>d@d-d.Zd=d/d0ZdAd1d2Zd3d4 Zd5d6 ZdS )BEasyCacheHolderFNr4   floatr]   r^   subsample_factorintoffload_cache_diffboolr#   r   c                 C  s   d| _ || _|| _|| _|| _|| _|| _d| _d| _d | _	d| _
d| _d| _d | _d | _d | _d | _i | _i | _g | _g | _d| _d| _d| _d | _|| _d S )N	EasyCacher   TFr   )r\   r4   r]   r^   rj   rl   r#   start_tend_tr1   r3   r(   r"   r&   r+   r5   r2   uuid_cache_diffsuuid_cache_diffs_audior6   r9   r`   allow_mismatchcut_from_startstate_metadatar   selfr4   r]   r^   rj   rl   r#   r   r   r   r   __init__   s4   
zEasyCacheHolder.__init__rN   returnc                 C     |d | j k  S Nr   rp   r8   rw   rN   r   r   r   r         z$EasyCacheHolder.is_past_end_timestepc                 C     |d | j k S r{   ro   r8   r}   r   r   r   r         z#EasyCacheHolder.should_do_easycachec                 C  
   | j d uS Nr+   rw   r   r   r   r)         
z%EasyCacheHolder.has_x_prev_subsampledc                 C  r   r   r5   r   r   r   r   has_output_prev_subsampled   r   z*EasyCacheHolder.has_output_prev_subsampledc                 C  r   r   r2   r   r   r   r   r/      r   z$EasyCacheHolder.has_output_prev_normc                 C  r   r   r1   r   r   r   r   r0      r   z0EasyCacheHolder.has_relative_transformation_ratec                 C      | | j| _| | j| _| S r   percent_to_sigmar]   ro   r^   rp   rw   r[   r   r   r   rY         z!EasyCacheHolder.prepare_timestepsTr?   torch.Tensorr   
list[UUID]r   c                 C  s   |j d t| }|| j}| jdkr4||| |d | dd d | jd d | jf }|r2| S |S ||| |d | df }|rH| S |S )Nr   r   .)r;   r_   indexr&   rj   r   )rw   r?   r   r   batch_offsetuuid_idx	to_returnr   r   r   r*      s   
0zEasyCacheHolder.subsamplec                   s   t  fdd|D S )Nc                 3  s    | ]}| j v V  qd S r   )rq   ).0uuidr   r   r   	<genexpr>
  s    z7EasyCacheHolder.can_apply_cache_diff.<locals>.<genexpr>)allrw   r   r   r   r   r!   	  s   z$EasyCacheHolder.can_apply_cache_diffr   c                 C  sL  | j |v r|s|  jd7  _|r| jn| j}|jd t| }t|D ]\}}t|| |d | g}|jdd  || jdd  kr| jsVt	d| j| j d|j dg }	d}
t
|| j|jD ]+\}}|
rld}
qc||kr| jr~|	t|| d  qc|	td | qc|	td  qc||	 }|t|  || |j7  < q#|S )Nr   r   zCached dims  don't match x dims  - this is no goodTF)r&   r`   rr   rq   r;   r_   	enumerateslicers   
ValueErrorziprt   r7   tupletodevice)rw   r?   r   r   cache_diffsr   ir   batch_sliceslicingskip_this_dimdim_udim_xr   r   r   r'     s.     "z EasyCacheHolder.apply_cache_diffrH   c                 C  s  |r| j n| j}|jdd  |jdd  krb| js&td|j d|j dg }d}t|j|jD ]*\}}	|sR||	krR| jrI|t|	| d  n|td | n|td  d}q1|t	| }|| }
|
jd t
| }t|D ]\}}|
|| |d | df ||< qsd S )	Nr   zOutput dims r   r   TFr   .)rr   rq   r;   rs   r   r   rt   r7   r   r   r_   r   )rw   rH   r?   r   r   r   r   skip_dimdim_or   diffr   r   r   r   r   r   r:   )  s&   "z!EasyCacheHolder.update_cache_diffc                 C  s
   | j |v S r   )r&   r   r   r   r   r   @  r   z#EasyCacheHolder.has_first_cond_uuidc                 C  sX   |j |j|jdd  f}| jd u r|| _dS || jkrdS t| j d |   dS )Nr   T9 - Tensor shape, dtype or device changed, resetting stateFr   dtyper;   ru   r$   warnr\   rb   rw   r?   metadatar   r   r   r    C  s   

zEasyCacheHolder.check_metadatac                 C  sf   d| _ d| _d| _d| _g | _d | _| `d | _| `d | _| `d | _| `	i | _	| `
i | _
d| _d | _| S )Nr   TFr   )r1   r3   r(   r"   r6   r&   r+   r5   r2   rq   rr   r`   ru   r   r   r   r   rb   N  s&   zEasyCacheHolder.resetc              	   C  $   t | j| j| j| j| j| j| jdS N)r   )rh   r4   r]   r^   rj   rl   r#   r   r   r   r   r   r   c     $zEasyCacheHolder.cloneFNr4   ri   r]   ri   r^   ri   rj   rk   rl   rm   r#   rm   r   rk   rN   ri   ry   rm   ry   rm   T)r?   r   r   r   r   rm   ry   r   )r   r   ry   rm   )F)r?   r   r   r   r   rm   )rH   r   r?   r   r   r   r   rm   r?   r   ry   rm   )__name__
__module____qualname__rx   r   r   r)   r   r/   r0   rY   r*   r!   r'   r:   r   r    rb   r   r   r   r   r   rh      s"    
 







rh   c                   @  (   e Zd ZedddZedddZdS )EasyCacheNodery   	io.Schemac                 C     t jdddddt jjdddt jjdd	d
dddddt jjdd	ddddddt jjdd	ddddddt jjdddddgt jjddgdS )Nrn   z Native EasyCache implementation.advanced/debug/modelTrZ   zThe model to add EasyCache to.tooltipr4   r   皙?      @{Gz?'The threshold for reusing cached steps.mindefaultmaxstepr   advancedr]   333333?rS   z5The relative sampling step to begin use of EasyCache.r^   ffffff?z3The relative sampling step to end use of EasyCache.r#   F#Whether to log verbose information.r   r   r   zThe model with EasyCache.node_iddisplay_namedescriptioncategoryis_experimentalinputsoutputsr   SchemaModelInputFloatBooleanOutputclsr   r   r   define_schemah     zEasyCacheNode.define_schemarZ   io.Model.Typer4   ri   r]   r^   r#   rm   io.NodeOutputc              	   C  st   |  }t|||dd||jjjd|jd d< |tjj	j
dt |tjj	jdt |tjj	jdt t|S )N   Frj   rl   r#   r   r   r   )r   rh   rZ   latent_formatlatent_channelsrO   add_wrapper_with_keyrV   patcher_extension
WrappersMPOUTER_SAMPLErg   CALC_COND_BATCHrQ   DIFFUSION_MODELrL   r   
NodeOutputr   rZ   r4   r]   r^   r#   r   r   r   execute|  s   &
zEasyCacheNode.executeNry   r   rZ   r   r4   ri   r]   ri   r^   ri   r#   rm   ry   r   r   r   r   classmethodr   r   r   r   r   r   r   g  
    r   c                   @  s   e Zd Zd2d3ddZd4ddZd5ddZd5ddZd4ddZd4ddZd4ddZ	d4ddZ
dd  Zd6d7d%d&Zd8d'd(Zd9d*d+Zd:d,d-Zd.d/ Zd0d1 ZdS );LazyCacheHolderFNr4   ri   r]   r^   rj   rk   rl   rm   r#   r   c                 C  s   d| _ || _|| _|| _|| _|| _|| _d| _d| _d | _	d| _
d| _d | _d | _d | _d | _g | _g | _d| _d | _|| _d S )N	LazyCacher   Tr   )r\   r4   r]   r^   rj   rl   r#   ro   rp   r1   r3   r(   r+   r5   r2   
cache_diffr6   r9   r`   ru   r   rv   r   r   r   rx     s*   
zLazyCacheHolder.__init__ry   c                 C  r   r   r   r   r   r   r   has_cache_diff  r   zLazyCacheHolder.has_cache_diffrN   c                 C  rz   r{   r|   r}   r   r   r   r     r~   z$LazyCacheHolder.is_past_end_timestepc                 C  r   r{   r   r}   r   r   r   r     r   z#LazyCacheHolder.should_do_easycachec                 C  r   r   r   r   r   r   r   r)     r   z%LazyCacheHolder.has_x_prev_subsampledc                 C  r   r   r   r   r   r   r   r     r   z*LazyCacheHolder.has_output_prev_subsampledc                 C  r   r   r   r   r   r   r   r/     r   z$LazyCacheHolder.has_output_prev_normc                 C  r   r   r   r   r   r   r   r0     r   z0LazyCacheHolder.has_relative_transformation_ratec                 C  r   r   r   r   r   r   r   rY     r   z!LazyCacheHolder.prepare_timestepsTr?   r   r   c                 C  sH   | j dkr|dd d | j d d | j f }|r| S |S |r"| S |S )Nr   .)rj   r   )rw   r?   r   r   r   r   r   r*     s   
zLazyCacheHolder.subsamplec                 C  s    |  j d7  _ || j|j S )Nr   )r`   r   r   r   )rw   r?   r   r   r   r'     s   z LazyCacheHolder.apply_cache_diffrH   c                 C  s   || | _ d S r   r   )rw   rH   r?   r   r   r   r:     s   z!LazyCacheHolder.update_cache_diffc                 C  sP   |j |j|jf}| jd u r|| _dS || jkrdS t| j d |   dS )NTr   Fr   r   r   r   r   r      s   

zLazyCacheHolder.check_metadatac                 C  sV   d| _ d| _d| _g | _g | _| `d | _| `d | _| `d | _| `d | _d| _	d | _
| S )Nr   Tr   )r1   r3   r(   r6   r9   r   r+   r5   r2   r`   ru   r   r   r   r   rb     s    zLazyCacheHolder.resetc              	   C  r   r   )r   r4   r]   r^   rj   rl   r#   r   r   r   r   r   r     r   zLazyCacheHolder.cloner   r   r   r   r   )r?   r   r   rm   ry   r   )r?   r   )rH   r   r?   r   r   )r   r   r   rx   r   r   r   r)   r   r/   r0   rY   r*   r'   r:   r    rb   r   r   r   r   r   r     s     










r   c                   @  r   )LazyCacheNodery   r   c                 C  r   )Nr   zA homebrew version of EasyCache - even 'easier' version of EasyCache to implement. Overall works worse than EasyCache, but better in some rare cases AND universal compatibility with everything in ComfyUI.r   TrZ   zThe model to add LazyCache to.r   r4   r   r   r   r   r   r   r]   r   rS   z5The relative sampling step to begin use of LazyCache.r^   r   z3The relative sampling step to end use of LazyCache.r#   Fr   r   zThe model with LazyCache.r   r   r   r   r   r   r     r   zLazyCacheNode.define_schemarZ   r   r4   ri   r]   r^   r#   rm   r   c              	   C  s`   |  }t|||dd||jjjd|jd d< |tjj	j
dt |tjj	jdt t|S )Nr   Fr   r   r   	lazycache)r   r   rZ   r   r   rO   r   rV   r   r   r   rg   PREDICT_NOISErP   r   r   r   r   r   r   r     s
   &
zLazyCacheNode.executeNr   r   r   r   r   r   r   r    r   r  c                   @  s   e Zd ZdddZdS )EasyCacheExtensionry   list[type[io.ComfyNode]]c                   s
   t tgS r   )r   r  r   r   r   r   get_node_list  s   z EasyCacheExtension.get_node_listN)ry   r  )r   r   r   r  r   r   r   r   r  
  s    r  c                   C  s   t  S r   )r  r   r   r   r   comfy_entrypoint  s   r  )
__future__r   typingr   r   comfy_api.latestr   r   comfy.patcher_extensionrV   r$   torchcomfy.model_patcherr   r   r   rL   rP   rQ   rg   rh   	ComfyNoder   r   r  r  r  r   r   r   r   <module>   s*    T6 $f