o
    i^                     @  sb  d dl mZ d dlmZmZ d dlZd dlZd dlZd dl	m
Z
 d dlmZmZ d dlZd dlZd dlZerHd dlmZ d dlmZ d dlmZ G d	d
 d
eZG dd deZG dd deZG dd dZdd g fdpddZe
G dd dZe
G dd  d Zed!g d"Z G d#d$ d$eZ!dqd&d'Z"drd*d+Z#d,d- Z$drd.d/Z%dsdtd4d5Z&dud6d7Z'G d8d9 d9Z(dvd>d?Z)dvd@dAZ*dvdBdCZ+dvdDdEZ,dwdFdGZ-e(j.e)e(j/e*e(j0e+e(j1e,iZ2dxdJdKZ3dsdydPdQZ4dzdRdSZ5dzdTdUZ6d{dVdWZ7G dXdY dYZ8e8j9e5e8j:e6e8j;e6e8j<e7iZ=d|d[d\Z>d]d^ Z?d}dadbZ@d~dddeZAddfdgZBddhdiZCddndoZDdS )    )annotations)TYPE_CHECKINGCallableN)	dataclass)ABCabstractmethod)	BaseModel)ModelPatcher)ControlBasec                   @  s0   e Zd Zdd ZedddZedd	d
ZdS )ContextWindowABCc                 C     d S N selfr   r   3/mnt/c/Users/fbmor/ComfyUI/comfy/context_windows.py__init__      zContextWindowABC.__init__fulltorch.Tensorreturnc                 C     t d)z@
        Get torch.Tensor applicable to current window.
        Not implemented.NotImplementedError)r   r   r   r   r   
get_tensor      zContextWindowABC.get_tensorto_addc                 C  r   )z
        Apply torch.Tensor of window to the full tensor, in place. Returns reference to updated full tensor, not a copy.
        r   r   )r   r   r   r   r   r   
add_window   r   zContextWindowABC.add_windowNr   r   r   r   r   r   r   r   r   r   )__name__
__module____qualname__r   r   r   r   r   r   r   r   r      s    r   c                   @  s@   e Zd Zdd ZedddZeddddZedddZdS ) ContextHandlerABCc                 C  r   r   r   r   r   r   r   r   $   r   zContextHandlerABC.__init__modelr   condslist[list[dict]]x_inr   timestepmodel_options	dict[str]r   boolc                 C  r   Nr   r   r   r%   r&   r(   r)   r*   r   r   r   should_use_context'      z$ContextHandlerABC.should_use_contextNcond_in
list[dict]windowr   listc                 C  r   r-   r   )r   r1   r(   r3   devicer   r   r   get_resized_cond+   r0   z"ContextHandlerABC.get_resized_condcalc_cond_batchr   c                 C  r   r-   r   )r   r7   r%   r&   r(   r)   r*   r   r   r   execute/   r0   zContextHandlerABC.executer%   r   r&   r'   r(   r   r)   r   r*   r+   r   r,   r   )r1   r2   r(   r   r3   r   r   r4   r7   r   r%   r   r&   r'   r(   r   r)   r   r*   r+   )r!   r"   r#   r   r   r/   r6   r8   r   r   r   r   r$   #   s    r$   c                   @  s@   e Zd ZddddZd	d	g fdddZddddZdddZd	S )IndexListContextWindowr   
index_list	list[int]diminttotal_framesc                 C  s:   || _ t|| _|| _|| _t|t| d|  | _d S )N   )r<   lencontext_lengthr>   r@   minmaxcenter_ratio)r   r<   r>   r@   r   r   r   r   6   s
   
zIndexListContextWindow.__init__Nr   r   r   c                 C  s|   |d u r| j }|dkr|j| dkr|S ttd g| | jg }|| }|r9ttd g| |g }|| ||< ||S Nr      )r>   shapetupleslicer<   to)r   r   r5   r>   retain_index_listidxr3   r   r   r   r   =   s   
z!IndexListContextWindow.get_tensorr   c                 C  s<   |d u r| j }ttd g| | jg }||  |7  < |S r   )r>   rJ   rK   r<   )r   r   r   r>   rN   r   r   r   r   I   s
   z!IndexListContextWindow.add_windownum_regionsc                 C  s"   t | j| }tt|d|d S rG   )r?   rF   rD   rE   )r   rO   
region_idxr   r   r   get_region_indexP   s   z'IndexListContextWindow.get_region_index)r   r   )r<   r=   r>   r?   r@   r?   r   r   r    )rO   r?   r   r?   )r!   r"   r#   r   r   r   rQ   r   r   r   r   r;   5   s
    r;   c                   @  s(   e Zd ZdZdZdZdZdZdd ZdS )	IndexListCallbacksevaluate_context_windowscombine_context_window_resultsexecute_startexecute_cleanupresize_cond_itemc                 C  s   i S r   r   r   r   r   r   init_callbacks\   r   z!IndexListCallbacks.init_callbacksN)	r!   r"   r#   EVALUATE_CONTEXT_WINDOWSCOMBINE_CONTEXT_WINDOW_RESULTSEXECUTE_STARTEXECUTE_CLEANUPRESIZE_COND_ITEMrX   r   r   r   r   rR   U   s    rR   rH   r3   r(   r   temporal_dimr?   temporal_scaletemporal_offsetrM   r=   c                   sN  t | drt| jtjsd S | j}||jkrd S ||}	|dkr/||j  }
|	|
kr/d S  dkrE|dkrE|j||||d}| 	|S  dkr_ fdd|j
 d  D }dd |D }nt|j
}|shd S |dkrg }|D ]}t|D ]}|| | }||	k r|| qvqp|}|sd S ttd g| |g }|| |}| 	|S )NcondrH   r   r>   rM   c                   s   g | ]}|  qS r   r   .0ir`   r   r   
<listcomp>t       zslice_cond.<locals>.<listcomp>c                 S  s   g | ]}d |kr|qS )r   r   rc   r   r   r   rg   u       )hasattr
isinstancera   torchTensorndimsizer>   r   
_copy_withr<   r4   rangeappendrJ   rK   rL   )
cond_valuer3   r(   r5   r^   r_   r`   rM   cond_tensor	cond_sizeexpected_sizeslicedindicesscaledre   ksirN   r   rf   r   
slice_cond`   sD   





r|   c                   @     e Zd ZU ded< ded< dS )ContextSchedulestrnamer   funcNr!   r"   r#   __annotations__r   r   r   r   r~         
 r~   c                   @  r}   )ContextFuseMethodr   r   r   r   Nr   r   r   r   r   r      r   r   ContextResults)
window_idxsub_conds_out	sub_condsr3   c                   @  s   e Zd Zddddddg dfdBddZdCdd ZdDdEd$d%ZdDdFd+d,ZdGd-d.ZdHd0d1ZdId4d5Z		!dJdKd8d9Z
dLd@dAZd!S )MIndexListContextHandlerrH   r   Fcontext_scheduler~   fuse_methodr   rC   r?   context_overlapcontext_strideclosed_loopr,   r>   	freenoisecond_retain_index_listr=   split_conds_to_windowsc                 C  sd   || _ || _|| _|| _|| _|| _|| _d| _|| _|	r'dd |		dD ng | _
|
| _i | _d S )Nr   c                 S  s   g | ]}t | qS r   )r?   strip)rd   xr   r   r   rg      ri   z4IndexListContextHandler.__init__.<locals>.<listcomp>,)r   r   rC   r   r   r   r>   _stepr   splitr   r   	callbacks)r   r   r   rC   r   r   r   r>   r   r   r   r   r   r   r      s   
z IndexListContextHandler.__init__r%   r   r&   r'   r(   r   r)   r*   r+   r   c              
   C  s\   | | j| jkr,td| j d| j d| | j d | jr*td| j  dS dS )NzUsing context windows z with overlap z for z frames.z%Retaining original cond for indexes: TF)ro   r>   rC   logginginfor   r   r.   r   r   r   r/      s   *z*IndexListContextHandler.should_use_contextNcontrolr
   c                 C  s   |j d ur| |j | |S r   )previous_controlnetprepare_control_objects)r   r   r5   r   r   r   r      s   
z/IndexListContextHandler.prepare_control_objectsr1   r2   r3   r;   r4   c                 C  s  |d u rd S g }| j r6t|dkr6|t|}td| d|jd  d|jd  d|jd || g}|D ]}| }|D ]}	z||	 }
t|
t	j
ry| j|
jk rp|
| j|| jkrp||
}||||	< nz|
|||	< nq|	d	kr| |
|||	< nct|
tr|
 }| D ]J\}}d
}tjtj| jD ]}|||||||}|d ur|||< d} nq|s| jd ur| jj|||||| jd}|d ur|||< d}|rqt|t	j
r| j|jk r|| j|| jks|j| jk r|d|| jkr|||||< q|dkrJt|drJt|jt	j
rJ|j}|jdkrI|d|| jkrI||j||dd||< q|dkrt|drt|jt	j
r|j}|jdkr|d|| jkr|j||d| jd}||||< qt|drt|jt	j
r| j|jjk r|j| j|| jks|jj| jk r|jd|| jkr||j|j|| jd||< q|dkr||j||< |j|| _q|||	< n|
||	< W ~
qA~
w | | q8|S )NrH   z)Splitting conds to windows; using region z for window r   -z with center ratio z.3fr   FT)rM   audio_embedra   r>   vace_context      rb   num_video_frames)!r   rB   rQ   r   r   r<   rF   copyrk   rl   rm   r>   rn   ro   r   rL   r   dictitemscomfypatcher_extensionget_all_callbacksrR   r]   r   _modelresize_cond_for_context_windowr   rj   ra   rp   rC   rr   )r   r1   r(   r3   r5   resized_condregionactual_condresized_actual_condkey	cond_itemactual_cond_itemnew_cond_itemcond_keyrs   handledcallbackresult
audio_cond	vace_condsliced_vacer   r   r   r6      s   2


$

&&&$&$,*

z(IndexListContextHandler.get_resized_condc                 C  sN   t j|d d |d dd}t |}t |dkrd S t|d  | _d S )Ntransformer_optionssample_sigmasr   g-C6?)rtol)rl   isclosenonzeronumelr?   itemr   )r   r)   r*   maskmatchesr   r   r   set_step  s
   
z IndexListContextHandler.set_steplist[IndexListContextWindow]c                   s4   | j j |} fdd|D }|S )Nc                   s   g | ]
}t |j d qS ))r>   r@   )r;   r>   )rd   r3   full_lengthr   r   r   rg         z?IndexListContextHandler.get_context_windows.<locals>.<listcomp>)ro   r>   r   r   )r   r%   r(   r*   context_windowsr   r   r   get_context_windows  s   z+IndexListContextHandler.get_context_windowsr7   r   c                   s  | _  ||  ||}tt|}fdd|D }	 jjtjkr1 fdd|D }
n
 fdd|D }
 fdd|D }t	j
tj jD ]}| |||| qN|D ]'} |||||g|}|D ]} |j|j|j|jt|||	|
|
 qlq\zL jjtjkr~
|	W t	j
tj jD ]}| |||| qS tt|	D ]}|	|  |
|   < q~
|	W t	j
tj jD ]}| |||| qS t	j
tj jD ]}| |||| qw )Nc                   s   g | ]}t  qS r   )rl   
zeros_likerd   _)r(   r   r   rg     s    z3IndexListContextHandler.execute.<locals>.<listcomp>c                   $   g | ]}t jt jjd qS r5   )rl   onesget_shape_for_dimr>   r5   r   r   r(   r   r   rg        $ c                   r   r   )rl   zerosr   r>   r5   r   r   r   r   rg     r   c                   s   g | ]}d gj  j  qS )g        )rI   r>   r   r   r   r   rg      s    )r   r   r   r4   	enumerater   r   ContextFuseMethodsRELATIVEr   r   r   rR   r[   r   rS   rT   r   r   r3   r   rB   r\   rq   )r   r7   r%   r&   r(   r)   r*   r   enumerated_context_windowsconds_finalcounts_finalbiases_finalr   enum_windowresultsr   re   r   r   r   r8     sD   zIndexListContextHandler.executer   (list[tuple[int, IndexListContextWindow]]c
                   s   g }
|D ]l\}t j  t jtjjD ]}||||||| |	 q|d d<  }j| dd} fdd|D }||||||} d urft	t
|D ]}|| j||< qY|
t||| q|
S )Nr   context_windowr   r   c                   s   g | ]
} | qS r   )r6   )rd   ra   r5   r   r3   r(   r   r   rg   I  r   zDIndexListContextHandler.evaluate_context_windows.<locals>.<listcomp>)r   model_management)throw_exception_if_processing_interruptedr   r   rR   rY   r   r   rq   rB   rL   r5   rr   r   )r   r7   r%   r(   r&   r)   r   r*   r5   first_devicer   r   r   sub_xsub_timestepr   r   re   r   r   r   rS   :  s   
z0IndexListContextHandler.evaluate_context_windowsr   total_windowsr   list[torch.Tensor]r   r   c                 C  s  | j jtjkrt|jD ]z\}}dt||jd |jd  d  |jd |jd  d d   }td|}tt	|D ]J}|
| | }|||  }|||  }t
td g| j |g }t
td g| j |g }|| | | || | |  || |< || |
| |< q;qn6t|j|j| j |j| |d}t||| j|jd}tt	|D ]}||| || |  ||	| | qtjtj| jD ]}|| |||||||||	|
 qd S )NrH   r   r   rA   g{Gz?)sigmar   )r   r   r   r   r   r<   absrE   rq   rB   rJ   rK   r>   get_context_weightsrC   rI   match_weights_to_dimr5   r   r   r   r   rR   rZ   r   )r   r(   r   r   r3   r   r   r)   r   r   r   posrN   biasre   
bias_totalprev_weight
new_weight
idx_window
pos_windowweightsweights_tensorr   r   r   r   rT   S  s,   @
(z6IndexListContextHandler.combine_context_window_results)r   r~   r   r   rC   r?   r   r?   r   r?   r   r,   r>   r?   r   r,   r   r=   r   r,   r9   r   )r   r
   r   r
   )r1   r2   r(   r   r3   r;   r   r4   )r)   r   r*   r+   )r%   r   r(   r   r*   r+   r   r   r:   )NN)
r7   r   r%   r   r(   r   r)   r   r   r   )r(   r   r3   r;   r   r?   r   r?   r)   r   r   r   r   r   r   r   )r!   r"   r#   r   r/   r   r6   r   r   r8   rS   rT   r   r   r   r   r      s    

	
R

&r   noise_shapec                 O  sh   | dd }|d u rtd| dd }|d ur(t|}t||j |j||j< | ||g|R i |S )Nr*   zdmodel_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.context_handler)get	Exceptionr4   rD   r>   rC   )executorr%   r   argskwargsr*   handlerr   r   r   _prepare_sampling_wrapperq  s   r   r%   r	   c                 C     |  tjjjdt d S )NContextWindows_prepare_sampling)add_wrapper_with_keyr   r   
WrappersMPPREPARE_SAMPLINGr   r%   r   r   r   create_prepare_sampling_wrapper}  
   r  c           
      O  s   | dd }|d u rtd| dd }	|	d u rtd|	js.| |||||g|R i |S t||	j|	j|	j|d }| |||||g|R i |S )Nr*   zbmodel_options not found in sampler_sample_wrapper; this should never happen, something went wrong.r   zdcontext_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.seed)r   r   r   apply_freenoiser>   rC   r   )
r   guidersigmas
extra_argsr   noiser   r   r*   r   r   r   r   _sampler_sample_wrapper  s   r  c                 C  r  )NContextWindows_sampler_sample)r  r   r   r  SAMPLER_SAMPLEr  r  r   r   r   create_sampler_sample_wrapper  r  r  r   list[float]r>   r   c                 C  sX   t |j}t| j|d}t|D ]}|d}qt|| d D ]}|d}q"|S )Nr   r   rH   r   )rB   rI   rl   rm   rL   rq   	unsqueeze)r   r(   r>   r5   
total_dimsr   r   r   r   r   r     s   
r   c                 C  sZ   t | j}g }t|D ]}|d q|| j|  t|| d D ]}|d q#|S )NrH   )rB   rI   rq   rr   )r(   r>   r  rI   r   r   r   r   r     s   
r   c                   @  s   e Zd ZdZdZdZdZdS )ContextScheduleslooped_uniformstandard_uniformstandard_staticbatchedN)r!   r"   r#   UNIFORM_LOOPEDUNIFORM_STANDARDSTATIC_STANDARDBATCHEDr   r   r   r   r    s
    r  
num_framesr   r*   r+   c              
     s   g } |j k r|tt  |S t|jttt	 |j  d }dt
|> D ]E}tt t|j }ttt|j| |  | |jrLdn|j  |j | |j D ]}| fddt|||j |  |D  qYq,|S )NrH   r   c                      g | ]}|  qS r   r   rd   er  r   r   rg     rh   z1create_windows_uniform_looped.<locals>.<listcomp>)rC   rr   r4   rq   rD   r   r?   npceillog2arangeroundordered_halvingr   r   r   )r  r   r*   windowsr   context_steppadjr   r#  r   create_windows_uniform_looped  s   
&,r.  c              
     s  g } |j kr|tt  |S t|jttt	 |j  d }dt
|> D ]@}tt t|j }ttt|j| |  | |j  |j | |j D ]}| fddt|||j |  |D  qTq,g }d}	|	t|k rt||	  \}
}|
r||	 | }t||	  d |||	d t|  vr||	d tt|||j   td|	D ]}||	 || kr||	  nq|	d7 }	|	t|k sw|  |D ]}|| q|S )NrH   c                   r   r   r   r!  r#  r   r   rg     rh   z3create_windows_uniform_standard.<locals>.<listcomp>r   r#  )rC   rr   r4   rq   rD   r   r?   r$  r%  r&  r'  r(  r)  r   r   rB   does_window_roll_overshift_window_to_endinsertreversepop)r  r   r*   r*  r   r+  r,  r-  delete_idxswin_iis_rollroll_idxroll_valpre_ire   r   r#  r   create_windows_uniform_standard  sD   
&, 
r:  c           	      C  s   g }| |j kr|tt|  |S |j |j }td| |D ]0}||j  }|| krA||  }|| }|tt|||j    |S |tt|||j   q|S Nr   )rC   rr   r4   rq   r   )	r  r   r*   r*  delta	start_idxendingfinal_deltafinal_start_idxr   r   r   create_windows_static_standard  s   

rA  c              	   C  s\   g }| |j kr|tt|  |S td| |j D ]}|tt|t||j  |  q|S r;  )rC   rr   r4   rq   rD   )r  r   r*   r*  r=  r   r   r   create_windows_batched  s   
"rB  c                 C  s   t t| gS r   r4   rq   )r  r   r   r   r   create_windows_default  s   rD  r   r   c                 C  .   t | d }|d u rtd|  dt| |S )NzUnknown context_schedule ''.)CONTEXT_MAPPINGr   
ValueErrorr~   )r   r   r   r   r   get_matching_context_schedule%     
rI  lengthr   idxsr   c                 C  s   |j j| ||||dS )N)r   r   r   rL  r   r   )rK  r   rL  r   r   r   r   r   r   ,  s   r   c                 K  s
   dg|  S )Ng      ?r   )rK  r   r   r   r   create_weights_flat0  s   
rN  c                 K  sv   | d dkr| d }t td|d dt t|dd }|S | d d }t td|d|g t t|d dd }|S )NrA   r   rH   r   rC  )rK  r   
max_weightweight_sequencer   r   r   create_weights_pyramid4  s   $*rQ  c                 K  sh   t | }t|dkrt dd|j}||d |j< t||d k r2t dd|j}|||j d < |S )Nr   gBA8rH   )rl   r   rD   linspacer   rE   )rK  r   rL  r   r   weights_torchramp_up	ramp_downr   r   r   create_weights_overlap_linear?  s   
rV  c                   @  s2   e Zd ZdZdZdZdZeeegZeeeegZdS )r   flatpyramidrelativezoverlap-linearN)	r!   r"   r#   FLATPYRAMIDr   OVERLAP_LINEARLISTLIST_STATICr   r   r   r   r   M  s    
r   r   c                 C  rE  )NzUnknown fuse_method 'rF  )FUSE_MAPPINGr   rH  r   rM  r   r   r   get_matching_fuse_method^  rJ  r`  c                 C  s(   | d}|d d d }t |d}|d S )N064br   rA   l            )r?   )valbin_strbin_flipas_intr   r   r   r)  e  s   
r)  r*  list[list[int]]c              
   C  sD   t t|}| D ]}|D ]}z|| W q ty   Y qw q|S r   )r4   rq   removerH  )r*  r  all_indexeswrb  r   r   r   get_missing_indexesq  s   rj  tuple[bool, int]c                 C  s:   d}t | D ]\}}|| }||k rd|f  S |}qdS )Nr   T)Fr   )r   )r3   r  prev_valre   rb  r   r   r   r/  |  s   r/  c                 C  s6   | d }t t| D ]}| | | | | | |< q
d S r;  )rq   rB   )r3   r  	start_valre   r   r   r   shift_window_to_start  s   rn  c                 C  sD   t | | | d }|| d }tt| D ]
}| | | | |< qd S )Nr   rH   )rn  rq   rB   )r3   r  end_val	end_deltare   r   r   r   r0    s   
r0  r  rC   r   r	  c                 C  s   t d tjdd|}| j| }|| }td|| |D ]E}|| }	t|||	 }
|
dkr3 | S tj|
|dd| }t	d g| j
 }|||< t	d g| j
 }t	|	|	|
 ||< | t| | t|< q| S )Nz#Context windows: Applying FreeNoisecpur   r   )	generatorr5   )r   r   rl   	Generatormanual_seedrI   rq   rD   randpermrK   rn   rJ   )r  r>   rC   r   r	  rr  latent_video_lengthr<  r=  	place_idxactual_deltalist_idxsource_slicetarget_slicer   r   r   r
    s"   

r
  )r3   r;   r(   r   r^   r?   r_   r?   r`   r?   rM   r=   )r   r   )r%   r	   r   )r   r  r(   r   r>   r?   r   r   )r(   r   r>   r?   r   r=   )r  r?   r   r   r*   r+   )r  r?   r   r   )r   r   r   r~   )
rK  r?   r   r?   rL  r=   r   r   r   r   )rK  r?   r   r  )rK  r?   r   r?   rL  r=   r   r   )r   r   r   r   )r*  rf  r  r?   r   r=   )r3   r=   r  r?   r   rk  )r3   r=   r  r?   )
r  r   r>   r?   rC   r?   r   r?   r	  r?   )E
__future__r   typingr   r   rl   numpyr$  collectionsdataclassesr   abcr   r   r   comfy.model_managementr   comfy.patcher_extensioncomfy.model_baser   comfy.model_patcherr	   comfy.controlnetr
   r   r$   r;   rR   r|   r~   r   
namedtupler   r   r   r  r  r  r   r   r  r.  r:  rA  rB  rD  r  r  r  r  rG  rI  r   rN  rQ  rV  r   rZ  r[  r   r\  r_  r`  r)  rj  r/  rn  r0  r
  r   r   r   r   <module>   sz     , 
[


	



0











