Description
Hi there, I just started using this repository for some cool experiments.
I noticed that some people are having trouble with the `get_latents_path` method in `run_tokenflow_pnp.py`, and I'm one of them.
If I understand the function's intent correctly, I would suggest changing its first few lines from:
```python
latents_path = os.path.join(config["latents_path"], f'sd_{config["sd_version"]}',
                            Path(config["data_path"]).stem, f'steps_{config["n_inversion_steps"]}')
latents_path = [x for x in glob.glob(f'{latents_path}/*') if '.' not in Path(x).name]
n_frames = [int([x for x in latents_path[i].split('/') if 'nframes' in x][0].split('_')[1]) for i in range(len(latents_path))]
latents_path = latents_path[np.argmax(n_frames)]
```
to:
```python
# Get parent folder of latents directories
latents_path_dir = os.path.join(config["latents_path"],
                                f'sd_{config["sd_version"]}',
                                Path(config["data_path"]).stem,
                                f'steps_{config["n_inversion_steps"]}')
# Get all possible folders that will contain latents according to different n_frames
latents_path_folders = [os.path.join(latents_path_dir, folder)
                        for folder in os.listdir(latents_path_dir)
                        if os.path.isdir(os.path.join(latents_path_dir, folder))
                        and 'nframes' in folder]
# Get all possible n_frames
n_frames = [int(latents_path_folder.split('_')[-1]) for latents_path_folder in latents_path_folders]
# Define latents_path according to the folder with the highest n_frames
latents_path = latents_path_folders[np.argmax(n_frames)]
```
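The original code splits paths on `'/'`, which breaks when the OS uses a different separator (backslashes on Windows). Below is a small self-contained sketch of both the failure and a `pathlib`-only alternative. It assumes the `.../steps_<N>/nframes_<M>` layout implied by the code above; the example path and the helper name `pick_latents_dir` are hypothetical, not part of the repository.

```python
import re
import tempfile
from pathlib import Path, PureWindowsPath

# Hypothetical Windows-style latents path (backslash separators).
win_path = r"latents\sd_2.1\car\steps_500\nframes_40"

# Original parsing: keep the '/'-separated component that mentions 'nframes'.
# With backslashes there is only one component (the whole path), so
# split('_')[1] picks '2.1\car\steps' and int() raises ValueError.
component = [x for x in win_path.split('/') if 'nframes' in x][0]
try:
    int(component.split('_')[1])
except ValueError as err:
    print("original parsing fails:", err)

# pathlib understands the separator, so .name is only the last folder.
print(PureWindowsPath(win_path).name)  # nframes_40

# Hypothetical helper: accept only folders named exactly nframes_<int>.
NFRAMES_RE = re.compile(r"nframes_(\d+)")


def pick_latents_dir(steps_dir: Path) -> Path:
    """Return the nframes_<M> subfolder of steps_dir with the largest M."""
    candidates = []
    for child in steps_dir.iterdir():
        match = NFRAMES_RE.fullmatch(child.name)
        if match and child.is_dir():
            candidates.append((int(match.group(1)), child))
    if not candidates:
        raise FileNotFoundError(f"no nframes_* folders in {steps_dir}")
    # Compare on the frame count only.
    return max(candidates, key=lambda c: c[0])[1]


# Usage: build a throwaway steps_<N> folder and pick from it.
with tempfile.TemporaryDirectory() as tmp:
    steps_dir = Path(tmp)
    for name in ("nframes_8", "nframes_40", "nframes_16"):
        (steps_dir / name).mkdir()
    (steps_dir / "notes.txt").write_text("not a latents folder")
    print(pick_latents_dir(steps_dir).name)  # nframes_40
```

Matching the whole folder name with a regex, rather than checking `'nframes' in folder`, also skips stray files or folders that merely contain the word.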
With this change, the code no longer relies on `split('/')`, which breaks on operating systems that use a different path separator (e.g. backslashes on Windows). It is also more readable and easier to debug, because the same `latents_path` variable is no longer reused for different things.

I would also suggest a more detailed description of what the config.yml file should contain, since the meaning of "data_path" was not obvious to me. If it has to point to a folder generated during preprocessing, filling in the config becomes less automated. So I also changed the `get_data` method so that I can pass the input video in the config, and it strips the extension to get the video's name (assuming preprocessing created a folder with that name):
```python
def get_data(self):
    # load frames
    if self.config["data_path"].endswith('.mp4'):
        self.config["data_path"] = os.path.splitext(self.config["data_path"])[0]
```
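The same idea can be sketched outside the class, together with the config keys the latents lookup reads (the kind of detail a config.yml description could list). This is a minimal sketch, assuming, as above, that preprocessing writes frames to a folder named after the video next to the video file. The helper names `resolve_data_path` and `expected_latents_dir`, the set of video extensions, and the config values are all hypothetical.

```python
import os
from pathlib import Path

# Hypothetical set of extensions treated as "a video file was passed".
VIDEO_EXTENSIONS = {".mp4", ".mov", ".avi"}


def resolve_data_path(data_path: str) -> str:
    """Map 'data/car.mp4' to the frames folder 'data/car'; leave folders unchanged."""
    root, ext = os.path.splitext(data_path)
    return root if ext.lower() in VIDEO_EXTENSIONS else data_path


def expected_latents_dir(config: dict) -> str:
    """Folder that is searched for nframes_* subfolders (mirrors the code above)."""
    return os.path.join(
        config["latents_path"],
        f'sd_{config["sd_version"]}',
        Path(config["data_path"]).stem,
        f'steps_{config["n_inversion_steps"]}',
    )


# Hypothetical config values, for illustration only.
config = {
    "data_path": "data/car.mp4",
    "latents_path": "latents",
    "sd_version": "2.1",
    "n_inversion_steps": 500,
}
print(resolve_data_path(config["data_path"]))  # data/car
print(expected_latents_dir(config))
```

Since `Path(...).stem` already drops the extension, the latents lookup resolves to the same folder whether `data_path` is `data/car.mp4` or `data/car`; the stripping mainly matters where `data_path` is used as the frames folder, as in `get_data`. Checking against an explicit extension set, rather than stripping any suffix, avoids truncating folder names that happen to contain a dot.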