# infer_api.py -- forked from ZebangCheng/Emotion-LLaMA
import os
import numpy as np
import cv2
import io
import base64
import requests
from PIL import Image
from io import BytesIO
import json
# NOTE(review): this rebinds a module attribute of requests.adapters; it presumably
# only affects code that reads the attribute after this point. HTTPAdapter may have
# captured its default earlier -- confirm that retries are actually applied.
requests.adapters.DEFAULT_RETRIES = 5 # increase the number of reconnection attempts
def get_pil_image_return_image_base64(raw_image_data):
    """Return a Base64 string for ``raw_image_data``.

    Supported inputs:

    * ``None`` -- a message is printed and ``None`` is returned.
    * ``dict`` containing a ``"bytes"`` key -- the ``bytes`` value is
      Base64-encoded as is.
    * ``str`` -- returned unchanged (NOTE(review): presumably already a
      Base64 string; confirm with callers).
    * ``list`` -- only the first element is converted; the rest are ignored.
    * ``PIL.Image.Image`` -- converted to RGB when necessary, serialized as
      PNG, and Base64-encoded.

    Raises:
        ValueError: if the ``"bytes"`` entry is not a ``bytes`` object, the
            list is empty, the PIL image cannot be encoded, or the input type
            is not supported.
    """
    if raw_image_data is None:
        print("raw_image_data is None")
        return None

    if isinstance(raw_image_data, dict) and "bytes" in raw_image_data:
        payload = raw_image_data["bytes"]
        if not isinstance(payload, bytes):
            raise ValueError("'bytes' key does not contain a bytes object")
        return base64.b64encode(payload).decode('utf-8')

    if isinstance(raw_image_data, str):
        return raw_image_data

    if isinstance(raw_image_data, list):
        # An empty list used to escape as IndexError; report it with the same
        # exception type as every other unsupported input.
        if not raw_image_data:
            raise ValueError("Unsupported image data format: empty list")
        return get_pil_image_return_image_base64(raw_image_data[0])

    if isinstance(raw_image_data, Image.Image):
        try:
            if raw_image_data.mode != "RGB":
                raw_image_data = raw_image_data.convert("RGB")
            buffered = io.BytesIO()
            raw_image_data.save(buffered, format="PNG")
            return base64.b64encode(buffered.getvalue()).decode('utf-8')
        except Exception as e:
            # Chain the cause so the underlying PIL error stays in the traceback.
            raise ValueError(f"Failed to encode PIL image: {e}") from e

    raise ValueError("Unsupported image data format")
# Scratch directory for the extracted frame, and the video to analyse.
temp_dir = "./temp"
video_path = "path/to/video.mp4"

# Prompt sent to the service together with the frame. It has to be defined:
# the request payload below references it. A Chinese prompt such as
# "分析视频中人物的情绪。" ("Analyze the emotions of the person in the video.")
# can be used instead.
prompt = "What emotions does the video convey?"

os.makedirs(temp_dir, exist_ok=True)

# Frame extraction: read only the first frame of the video, and always release
# the capture handle afterwards.
cap = cv2.VideoCapture(video_path)
try:
    success, frame = cap.read()
finally:
    cap.release()
if not success:
    # Stop here: without a frame there is nothing to write or to send.
    raise SystemExit(f"Failed to read video {video_path}.")

image_path = os.path.join(temp_dir, "temp_infer_EmotionLLaMA.jpg")
if not cv2.imwrite(image_path, frame):
    raise SystemExit(f"Failed to write frame to {image_path}.")

url = "http://10.14.3.47:7889/api/predict/"  # change this to your own url
headers = {"Content-Type": "application/json"}

# Re-open the saved frame with PIL and encode it; the file handle is closed
# as soon as the Base64 string has been produced.
with Image.open(image_path) as raw_image_data:
    image = get_pil_image_return_image_base64(raw_image_data)

data = {
    "data": [image, prompt]
}

# (connect, read) timeouts in seconds so a dead server cannot hang the script
# forever; the read timeout is generous because the server's response time is
# not known here.
request_timeout = (10, 600)
http_response = requests.post(
    url, headers=headers, data=json.dumps(data), timeout=request_timeout
)
# Surface HTTP errors explicitly instead of failing later on a missing key.
http_response.raise_for_status()
response = http_response.json()
print(response['data'][0])