Get latents path function improvement  #50

@fsbarros98

Hi there, I just started using this repository for some cool experiments.
I noticed some people are having trouble with the get_latents_path method in run_tokenflow_pnp.py, and I'm one of them.
If I understood the intent of that function correctly, I would suggest changing its first few lines from:

latents_path = os.path.join(config["latents_path"], f'sd_{config["sd_version"]}',
                     Path(config["data_path"]).stem, f'steps_{config["n_inversion_steps"]}')
latents_path = [x for x in glob.glob(f'{latents_path}/*') if '.' not in Path(x).name]
n_frames = [int([x for x in latents_path[i].split('/') if 'nframes' in x][0].split('_')[1]) for i in range(len(latents_path))]
latents_path = latents_path[np.argmax(n_frames)]

to

# Get parent folder of latents directories
latents_path_dir = os.path.join(config["latents_path"], 
                               f'sd_{config["sd_version"]}',
                               Path(config["data_path"]).stem, 
                               f'steps_{config["n_inversion_steps"]}')

# Get all possible folders that will contain latents according to different n_frames
latents_path_folders = [os.path.join(latents_path_dir, folder) 
                       for folder in os.listdir(latents_path_dir) 
                       if os.path.isdir(os.path.join(latents_path_dir, folder)) 
                       and 'nframes' in folder]

# Get all possible n_frames
n_frames = [int(latents_path_folder.split('_')[-1]) for latents_path_folder in latents_path_folders]

# Define latents_path according to the folder with the highest n_frames
latents_path = latents_path_folders[np.argmax(n_frames)]

This avoids OS-specific problems with split('/'), which does not work on paths that use a different separator (for example on Windows). It is also more readable and easier to debug, since the latents_path variable is no longer reused for different things.

I would also suggest a more detailed description of what should go in the config.yml file, since the meaning of "data_path" was not obvious to me. If it has to point to a folder generated during preprocessing, writing the config becomes less automated. So I also changed the first lines of the get_data method so that I can pass the input video in the config, and it extracts the name of the video (assuming preprocessing added a folder with that name...):

    def get_data(self):
        # load frames
        if self.config["data_path"].endswith('.mp4'):
            self.config["data_path"] = os.path.splitext(self.config["data_path"])[0]
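
To illustrate the path-separator point, here is a minimal standalone sketch of the same lookup written with pathlib. The helper name find_latents_dir and the example path are made up for illustration and are not part of the repository. It assumes the latents folders are named like nframes_<N>, as the original parsing code implies.

```python
import tempfile
from pathlib import Path, PureWindowsPath


def find_latents_dir(parent):
    """Return the 'nframes_<N>' subfolder of `parent` with the largest N.

    Hypothetical helper for illustration; not part of the repository.
    """
    candidates = {}
    for child in Path(parent).iterdir():
        # Only consider directories named like 'nframes_40'
        prefix, _, suffix = child.name.rpartition('_')
        if child.is_dir() and prefix == 'nframes' and suffix.isdigit():
            candidates[int(suffix)] = child
    if not candidates:
        raise FileNotFoundError(f'no nframes_* folders found in {parent}')
    return candidates[max(candidates)]


# Why split('/') is fragile: a Windows-style path has no '/' to split on.
# The path below is an invented example.
windows_style = r'latents\sd_2.1\video\steps_500\nframes_40'
print(windows_style.split('/'))             # the whole string, as one element
print(PureWindowsPath(windows_style).name)  # nframes_40

# Small self-check using a temporary directory tree
with tempfile.TemporaryDirectory() as tmp:
    for n in (8, 40, 16):
        (Path(tmp) / f'nframes_{n}').mkdir()
    (Path(tmp) / 'notes.txt').write_text('not a latents folder')
    best = find_latents_dir(tmp)
    print(best.name)  # nframes_40
```

Raising an explicit error when no nframes folder exists is a design choice of this sketch; it gives a clearer message than the failure np.argmax produces on an empty list.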
