Skip to content

TypeError: _extract_past_from_model_output() got an unexpected keyword argument 'standardize_cache_format' #204

@sjghh

Description

@sjghh

System Info / 系統信息

transformers 4.47.0

Who can help? / 谁可以帮助到您?

No response

Information / 问题信息

  • The official example scripts / 官方的示例脚本
  • My own modified scripts / 我自己修改的脚本和任务

Reproduction / 复现过程

video_demo/inference.py

import io
import os
import numpy as np
import torch
from decord import cpu, VideoReader, bridge
from transformers import AutoModelForCausalLM, AutoTokenizer

# Paths and prompt used by the reproduction run.
MODEL_PATH = "/data/benchmark_weight/cogvlm2-video-llama3-chat"
VIDEO_PATH = "/data/video-llama2/VideoLLaMA2-main/datasets/custom_sft/videos/MER2023/sample_00000003.mp4"
PROMPT = "Could you describe the features of the individual in the video? "

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Use bfloat16 only on CUDA devices whose major compute capability is >= 8;
# every other configuration uses float16.
TORCH_TYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)

# Disable tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def load_video(video_path, strategy='chat'):
    """Read a video file and return a batch of sampled frames.

    Args:
        video_path: Path of the video file to read.
        strategy: Frame-sampling strategy.
            ``'base'`` picks ``num_frames`` evenly spaced frame indices from
            the first 60 seconds of the video (or the whole video if shorter).
            ``'chat'`` picks, for each whole second, the frame whose start
            timestamp is closest to that second, up to ``num_frames`` frames.

    Returns:
        The frames returned by ``VideoReader.get_batch`` with the last axis
        moved to the front via ``permute(3, 0, 1, 2)``.

    Raises:
        ValueError: If ``strategy`` is neither ``'base'`` nor ``'chat'``.
    """
    # Ask decord to hand frames back as torch tensors.
    bridge.set_bridge('torch')
    with open(video_path, 'rb') as f:
        mp4_stream = f.read()

    num_frames = 24
    decord_vr = VideoReader(io.BytesIO(mp4_stream), ctx=cpu(0))

    total_frames = len(decord_vr)
    if strategy == 'base':
        clip_end_sec = 60
        clip_start_sec = 0
        start_frame = int(clip_start_sec * decord_vr.get_avg_fps())
        # Clamp the clip end to the actual number of frames in the video.
        end_frame = min(total_frames, int(clip_end_sec * decord_vr.get_avg_fps()))
        frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)
    elif strategy == 'chat':
        timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))
        # Keep only the first value of each timestamp entry
        # (presumably the frame's start time; confirm against decord docs).
        timestamps = [i[0] for i in timestamps]
        max_second = round(max(timestamps)) + 1
        frame_id_list = []
        for second in range(max_second):
            closest_num = min(timestamps, key=lambda x: abs(x - second))
            index = timestamps.index(closest_num)
            frame_id_list.append(index)
            if len(frame_id_list) >= num_frames:
                break
    else:
        # Previously an unknown strategy fell through with no frame list and
        # failed later inside get_batch; fail here with a clear message.
        raise ValueError(
            f"Unknown strategy {strategy!r}; expected 'base' or 'chat'"
        )

    video_data = decord_vr.get_batch(frame_id_list)
    # Move the last axis to the front
    # (presumably (T, H, W, C) -> (C, T, H, W); confirm against decord docs).
    video_data = video_data.permute(3, 0, 1, 2)
    return video_data

# Initialize tokenizer and model
# trust_remote_code=True lets transformers run the modeling code shipped
# alongside the checkpoint at MODEL_PATH.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
)

# Load the weights in TORCH_TYPE, switch to eval mode, then move to DEVICE.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=TORCH_TYPE,
    trust_remote_code=True,
).eval().to(DEVICE)

def predict(prompt, video_path, temperature=0.7):
    """Generate the model's text response to ``prompt`` for one video.

    Args:
        prompt: The user query passed to the model.
        video_path: Path of the video file; frames are loaded with
            ``load_video`` using the ``'chat'`` strategy.
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        The decoded response string, containing only the newly generated
        tokens (the prompt tokens are sliced off) with special tokens skipped.
    """
    strategy = 'chat'

    video = load_video(video_path, strategy=strategy)

    history = []
    query = prompt
    inputs = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query=query,
        images=[video],
        history=history,
        template_version=strategy
    )
    # Add a leading batch dimension to each tensor and move it to the target
    # device; the video tensor is additionally cast to the model dtype.
    inputs = {
        'input_ids': inputs['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': inputs['attention_mask'].unsqueeze(0).to(DEVICE),
        'images': [[inputs['images'][0].to(DEVICE).to(TORCH_TYPE)]],
    }
    gen_kwargs = {
        "max_new_tokens": 2048,
        "pad_token_id": 128002,
        "top_k": 1,
        "do_sample": True,  # Enable sampling to match `temperature` usage
        "top_p": 0.1,
        "temperature": temperature,
    }
    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        # Drop the prompt tokens so only the generated continuation remains.
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response

# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    response = predict(PROMPT, VIDEO_PATH)
    print(response)

TypeError: _extract_past_from_model_output() got an unexpected keyword argument 'standardize_cache_format'

Expected behavior / 期待表现

    cache_name, cache = self._extract_past_from_model_output(
        outputs, standardize_cache_format=standardize_cache_format
    )

我明白需要将上面的改为 cache_name, cache = self._extract_past_from_model_output(outputs)
但是推理代码中有 trust_remote_code=True,会自动更新;当 trust_remote_code=False 后报错安全问题。我尝试了 transformers 的多个版本,包括 4.44.0、4.42.4、4.40.2、4.41.2 以及最新版本 4.47.0,均无法解决,期待得到您的帮助。

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions