Skip to content

Commit 3988128

Browse files
feat(ui): add _all_ image outputs to gallery (including collections)
1 parent c768f47 commit 3988128

File tree

2 files changed

+103
-45
lines changed

2 files changed

+103
-45
lines changed

invokeai/frontend/web/src/features/nodes/types/common.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ export const zImageField = z.object({
88
image_name: z.string().trim().min(1),
99
});
1010
export type ImageField = z.infer<typeof zImageField>;
11+
export const isImageField = (field: unknown): field is ImageField => zImageField.safeParse(field).success;
12+
const zImageFieldCollection = z.array(zImageField);
13+
type ImageFieldCollection = z.infer<typeof zImageFieldCollection>;
14+
export const isImageFieldCollection = (field: unknown): field is ImageFieldCollection =>
15+
zImageFieldCollection.safeParse(field).success;
1116

1217
export const zBoardField = z.object({
1318
board_id: z.string().trim().min(1),

invokeai/frontend/web/src/services/events/onInvocationComplete.tsx

Lines changed: 98 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@ import { deepClone } from 'common/util/deepClone';
44
import { stagingAreaImageStaged } from 'features/controlLayers/store/canvasStagingAreaSlice';
55
import { boardIdSelected, galleryViewChanged, imageSelected, offsetChanged } from 'features/gallery/store/gallerySlice';
66
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
7+
import { isImageField, isImageFieldCollection } from 'features/nodes/types/common';
78
import { zNodeStatus } from 'features/nodes/types/invocation';
89
import { CANVAS_OUTPUT_PREFIX } from 'features/nodes/util/graph/graphBuilderUtils';
10+
import type { ApiTagDescription } from 'services/api';
911
import { boardsApi } from 'services/api/endpoints/boards';
1012
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
1113
import type { ImageDTO, S } from 'services/api/types';
1214
import { getCategories, getListImagesUrl } from 'services/api/util';
1315
import { $lastProgressEvent } from 'services/events/stores';
16+
import type { Param0 } from 'tsafe';
17+
import { objectEntries } from 'tsafe';
1418
import type { JsonObject } from 'type-fest';
1519

1620
const log = logger('events');
@@ -22,58 +26,98 @@ const isCanvasOutputNode = (data: S['InvocationCompleteEvent']) => {
2226
const nodeTypeDenylist = ['load_image', 'image'];
2327

2428
export const buildOnInvocationComplete = (getState: () => RootState, dispatch: AppDispatch) => {
25-
const addImageToGallery = (data: S['InvocationCompleteEvent'], imageDTO: ImageDTO) => {
29+
const addImagesToGallery = (data: S['InvocationCompleteEvent'], imageDTOs: ImageDTO[]) => {
2630
if (nodeTypeDenylist.includes(data.invocation.type)) {
27-
log.trace('Skipping node type denylisted');
31+
log.trace(`Skipping denylisted node type (${data.invocation.type})`);
2832
return;
2933
}
3034

31-
if (imageDTO.is_intermediate) {
35+
// For efficiency's sake, we want to minimize the number of dispatches and invalidations we do.
36+
// We'll keep track of each change we need to make and do them all at once.
37+
const boardTotalAdditions: Record<string, number> = {};
38+
const boardTagIdsToInvalidate: Set<string> = new Set();
39+
const imageListTagIdsToInvalidate: Set<string> = new Set();
40+
41+
for (const imageDTO of imageDTOs) {
42+
if (imageDTO.is_intermediate) {
43+
return;
44+
}
45+
46+
const boardId = imageDTO.board_id ?? 'none';
47+
// update the total images for the board
48+
boardTotalAdditions[boardId] = (boardTotalAdditions[boardId] || 0) + 1;
49+
// invalidate the board tag
50+
boardTagIdsToInvalidate.add(boardId);
51+
// invalidate the image list tag
52+
imageListTagIdsToInvalidate.add(
53+
getListImagesUrl({
54+
board_id: boardId,
55+
categories: getCategories(imageDTO),
56+
})
57+
);
58+
}
59+
60+
// Update all the board image totals at once
61+
const entries: Param0<typeof boardsApi.util.upsertQueryEntries> = [];
62+
for (const [boardId, amountToAdd] of objectEntries(boardTotalAdditions)) {
63+
// upsertQueryEntries doesn't provide a "recipe" function for the update - we must provide the new value
64+
// directly. So we need to select the board totals first.
65+
const total = boardsApi.endpoints.getBoardImagesTotal.select(boardId)(getState()).data?.total;
66+
if (total === undefined) {
67+
// No cache exists for this board, so we can't update it.
68+
continue;
69+
}
70+
entries.push({
71+
endpointName: 'getBoardImagesTotal',
72+
arg: boardId,
73+
value: { total: total + amountToAdd },
74+
});
75+
}
76+
dispatch(boardsApi.util.upsertQueryEntries(entries));
77+
78+
// Invalidate all tags at once
79+
const boardTags: ApiTagDescription[] = Array.from(boardTagIdsToInvalidate).map((boardId) => ({
80+
type: 'Board' as const,
81+
id: boardId,
82+
}));
83+
const imageListTags: ApiTagDescription[] = Array.from(imageListTagIdsToInvalidate).map((imageListId) => ({
84+
type: 'ImageList' as const,
85+
id: imageListId,
86+
}));
87+
dispatch(imagesApi.util.invalidateTags([...boardTags, ...imageListTags]));
88+
89+
// Finally, we may need to autoswitch to the new image. We'll only do it for the last image in the list.
90+
91+
const lastImageDTO = imageDTOs.at(-1);
92+
93+
if (!lastImageDTO) {
3294
return;
3395
}
3496

35-
// update the total images for the board
36-
dispatch(
37-
boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => {
38-
draft.total += 1;
39-
})
40-
);
41-
42-
dispatch(
43-
imagesApi.util.invalidateTags([
44-
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
45-
{
46-
type: 'ImageList',
47-
id: getListImagesUrl({
48-
board_id: imageDTO.board_id ?? 'none',
49-
categories: getCategories(imageDTO),
50-
}),
51-
},
52-
])
53-
);
97+
const { image_name, board_id } = lastImageDTO;
5498

5599
const { shouldAutoSwitch, selectedBoardId, galleryView, offset } = getState().gallery;
56100

57101
// If auto-switch is enabled, select the new image
58102
if (shouldAutoSwitch) {
59103
// If the image is from a different board, switch to that board - this will also select the image
60-
if (imageDTO.board_id && imageDTO.board_id !== selectedBoardId) {
104+
if (board_id && board_id !== selectedBoardId) {
61105
dispatch(
62106
boardIdSelected({
63-
boardId: imageDTO.board_id,
64-
selectedImageName: imageDTO.image_name,
107+
boardId: board_id,
108+
selectedImageName: image_name,
65109
})
66110
);
67-
} else if (!imageDTO.board_id && selectedBoardId !== 'none') {
111+
} else if (!board_id && selectedBoardId !== 'none') {
68112
dispatch(
69113
boardIdSelected({
70114
boardId: 'none',
71-
selectedImageName: imageDTO.image_name,
115+
selectedImageName: image_name,
72116
})
73117
);
74118
} else {
75119
// Else just select the image, no need to switch boards
76-
dispatch(imageSelected(imageDTO));
120+
dispatch(imageSelected(lastImageDTO));
77121

78122
if (galleryView !== 'images') {
79123
// We also need to update the gallery view to images. This also updates the offset.
@@ -86,12 +130,25 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
86130
}
87131
};
88132

89-
const getResultImageDTO = (data: S['InvocationCompleteEvent']) => {
133+
const getResultImageDTOs = async (data: S['InvocationCompleteEvent']): Promise<ImageDTO[]> => {
90134
const { result } = data;
91-
if (result.type === 'image_output') {
92-
return getImageDTOSafe(result.image.image_name);
135+
const imageDTOs: ImageDTO[] = [];
136+
for (const [_name, value] of objectEntries(result)) {
137+
if (isImageField(value)) {
138+
const imageDTO = await getImageDTOSafe(value.image_name);
139+
if (imageDTO) {
140+
imageDTOs.push(imageDTO);
141+
}
142+
} else if (isImageFieldCollection(value)) {
143+
for (const imageField of value) {
144+
const imageDTO = await getImageDTOSafe(imageField.image_name);
145+
if (imageDTO) {
146+
imageDTOs.push(imageDTO);
147+
}
148+
}
149+
}
93150
}
94-
return null;
151+
return imageDTOs;
95152
};
96153

97154
const handleOriginWorkflows = async (data: S['InvocationCompleteEvent']) => {
@@ -107,16 +164,15 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
107164
upsertExecutionState(nes.nodeId, nes);
108165
}
109166

110-
const imageDTO = await getResultImageDTO(data);
111-
112-
if (imageDTO && !imageDTO.is_intermediate) {
113-
addImageToGallery(data, imageDTO);
114-
}
167+
const imageDTOs = await getResultImageDTOs(data);
168+
addImagesToGallery(data, imageDTOs);
115169
};
116170

117171
const handleOriginCanvas = async (data: S['InvocationCompleteEvent']) => {
118-
const imageDTO = await getResultImageDTO(data);
172+
const imageDTOs = await getResultImageDTOs(data);
119173

174+
// We expect only a single image in the canvas output
175+
const imageDTO = imageDTOs[0];
120176
if (!imageDTO) {
121177
return;
122178
}
@@ -127,20 +183,17 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
127183
if (data.result.type === 'image_output') {
128184
dispatch(stagingAreaImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } }));
129185
}
130-
addImageToGallery(data, imageDTO);
186+
addImagesToGallery(data, [imageDTO]);
131187
}
132188
} else if (!imageDTO.is_intermediate) {
133189
// Desintaion is gallery
134-
addImageToGallery(data, imageDTO);
190+
addImagesToGallery(data, [imageDTO]);
135191
}
136192
};
137193

138194
const handleOriginOther = async (data: S['InvocationCompleteEvent']) => {
139-
const imageDTO = await getResultImageDTO(data);
140-
141-
if (imageDTO && !imageDTO.is_intermediate) {
142-
addImageToGallery(data, imageDTO);
143-
}
195+
const imageDTOs = await getResultImageDTOs(data);
196+
addImagesToGallery(data, imageDTOs);
144197
};
145198

146199
return async (data: S['InvocationCompleteEvent']) => {

0 commit comments

Comments
 (0)