@@ -4,13 +4,17 @@ import { deepClone } from 'common/util/deepClone';
4
4
import { stagingAreaImageStaged } from 'features/controlLayers/store/canvasStagingAreaSlice' ;
5
5
import { boardIdSelected , galleryViewChanged , imageSelected , offsetChanged } from 'features/gallery/store/gallerySlice' ;
6
6
import { $nodeExecutionStates , upsertExecutionState } from 'features/nodes/hooks/useNodeExecutionState' ;
7
+ import { isImageField , isImageFieldCollection } from 'features/nodes/types/common' ;
7
8
import { zNodeStatus } from 'features/nodes/types/invocation' ;
8
9
import { CANVAS_OUTPUT_PREFIX } from 'features/nodes/util/graph/graphBuilderUtils' ;
10
+ import type { ApiTagDescription } from 'services/api' ;
9
11
import { boardsApi } from 'services/api/endpoints/boards' ;
10
12
import { getImageDTOSafe , imagesApi } from 'services/api/endpoints/images' ;
11
13
import type { ImageDTO , S } from 'services/api/types' ;
12
14
import { getCategories , getListImagesUrl } from 'services/api/util' ;
13
15
import { $lastProgressEvent } from 'services/events/stores' ;
16
+ import type { Param0 } from 'tsafe' ;
17
+ import { objectEntries } from 'tsafe' ;
14
18
import type { JsonObject } from 'type-fest' ;
15
19
16
20
const log = logger ( 'events' ) ;
@@ -22,58 +26,98 @@ const isCanvasOutputNode = (data: S['InvocationCompleteEvent']) => {
22
26
const nodeTypeDenylist = [ 'load_image' , 'image' ] ;
23
27
24
28
export const buildOnInvocationComplete = ( getState : ( ) => RootState , dispatch : AppDispatch ) => {
25
- const addImageToGallery = ( data : S [ 'InvocationCompleteEvent' ] , imageDTO : ImageDTO ) => {
29
+ const addImagesToGallery = ( data : S [ 'InvocationCompleteEvent' ] , imageDTOs : ImageDTO [ ] ) => {
26
30
if ( nodeTypeDenylist . includes ( data . invocation . type ) ) {
27
- log . trace ( ' Skipping node type denylisted' ) ;
31
+ log . trace ( ` Skipping denylisted node type ( ${ data . invocation . type } )` ) ;
28
32
return ;
29
33
}
30
34
31
- if ( imageDTO . is_intermediate ) {
35
+ // For efficiency's sake, we want to minimize the number of dispatches and invalidations we do.
36
+ // We'll keep track of each change we need to make and do them all at once.
37
+ const boardTotalAdditions : Record < string , number > = { } ;
38
+ const boardTagIdsToInvalidate : Set < string > = new Set ( ) ;
39
+ const imageListTagIdsToInvalidate : Set < string > = new Set ( ) ;
40
+
41
+ for ( const imageDTO of imageDTOs ) {
42
+ if ( imageDTO . is_intermediate ) {
43
+ return ;
44
+ }
45
+
46
+ const boardId = imageDTO . board_id ?? 'none' ;
47
+ // update the total images for the board
48
+ boardTotalAdditions [ boardId ] = ( boardTotalAdditions [ boardId ] || 0 ) + 1 ;
49
+ // invalidate the board tag
50
+ boardTagIdsToInvalidate . add ( boardId ) ;
51
+ // invalidate the image list tag
52
+ imageListTagIdsToInvalidate . add (
53
+ getListImagesUrl ( {
54
+ board_id : boardId ,
55
+ categories : getCategories ( imageDTO ) ,
56
+ } )
57
+ ) ;
58
+ }
59
+
60
+ // Update all the board image totals at once
61
+ const entries : Param0 < typeof boardsApi . util . upsertQueryEntries > = [ ] ;
62
+ for ( const [ boardId , amountToAdd ] of objectEntries ( boardTotalAdditions ) ) {
63
+ // upsertQueryEntries doesn't provide a "recipe" function for the update - we must provide the new value
64
+ // directly. So we need to select the board totals first.
65
+ const total = boardsApi . endpoints . getBoardImagesTotal . select ( boardId ) ( getState ( ) ) . data ?. total ;
66
+ if ( total === undefined ) {
67
+ // No cache exists for this board, so we can't update it.
68
+ continue ;
69
+ }
70
+ entries . push ( {
71
+ endpointName : 'getBoardImagesTotal' ,
72
+ arg : boardId ,
73
+ value : { total : total + amountToAdd } ,
74
+ } ) ;
75
+ }
76
+ dispatch ( boardsApi . util . upsertQueryEntries ( entries ) ) ;
77
+
78
+ // Invalidate all tags at once
79
+ const boardTags : ApiTagDescription [ ] = Array . from ( boardTagIdsToInvalidate ) . map ( ( boardId ) => ( {
80
+ type : 'Board' as const ,
81
+ id : boardId ,
82
+ } ) ) ;
83
+ const imageListTags : ApiTagDescription [ ] = Array . from ( imageListTagIdsToInvalidate ) . map ( ( imageListId ) => ( {
84
+ type : 'ImageList' as const ,
85
+ id : imageListId ,
86
+ } ) ) ;
87
+ dispatch ( imagesApi . util . invalidateTags ( [ ...boardTags , ...imageListTags ] ) ) ;
88
+
89
+ // Finally, we may need to autoswitch to the new image. We'll only do it for the last image in the list.
90
+
91
+ const lastImageDTO = imageDTOs . at ( - 1 ) ;
92
+
93
+ if ( ! lastImageDTO ) {
32
94
return ;
33
95
}
34
96
35
- // update the total images for the board
36
- dispatch (
37
- boardsApi . util . updateQueryData ( 'getBoardImagesTotal' , imageDTO . board_id ?? 'none' , ( draft ) => {
38
- draft . total += 1 ;
39
- } )
40
- ) ;
41
-
42
- dispatch (
43
- imagesApi . util . invalidateTags ( [
44
- { type : 'Board' , id : imageDTO . board_id ?? 'none' } ,
45
- {
46
- type : 'ImageList' ,
47
- id : getListImagesUrl ( {
48
- board_id : imageDTO . board_id ?? 'none' ,
49
- categories : getCategories ( imageDTO ) ,
50
- } ) ,
51
- } ,
52
- ] )
53
- ) ;
97
+ const { image_name, board_id } = lastImageDTO ;
54
98
55
99
const { shouldAutoSwitch, selectedBoardId, galleryView, offset } = getState ( ) . gallery ;
56
100
57
101
// If auto-switch is enabled, select the new image
58
102
if ( shouldAutoSwitch ) {
59
103
// If the image is from a different board, switch to that board - this will also select the image
60
- if ( imageDTO . board_id && imageDTO . board_id !== selectedBoardId ) {
104
+ if ( board_id && board_id !== selectedBoardId ) {
61
105
dispatch (
62
106
boardIdSelected ( {
63
- boardId : imageDTO . board_id ,
64
- selectedImageName : imageDTO . image_name ,
107
+ boardId : board_id ,
108
+ selectedImageName : image_name ,
65
109
} )
66
110
) ;
67
- } else if ( ! imageDTO . board_id && selectedBoardId !== 'none' ) {
111
+ } else if ( ! board_id && selectedBoardId !== 'none' ) {
68
112
dispatch (
69
113
boardIdSelected ( {
70
114
boardId : 'none' ,
71
- selectedImageName : imageDTO . image_name ,
115
+ selectedImageName : image_name ,
72
116
} )
73
117
) ;
74
118
} else {
75
119
// Else just select the image, no need to switch boards
76
- dispatch ( imageSelected ( imageDTO ) ) ;
120
+ dispatch ( imageSelected ( lastImageDTO ) ) ;
77
121
78
122
if ( galleryView !== 'images' ) {
79
123
// We also need to update the gallery view to images. This also updates the offset.
@@ -86,12 +130,25 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
86
130
}
87
131
} ;
88
132
89
- const getResultImageDTO = ( data : S [ 'InvocationCompleteEvent' ] ) => {
133
+ const getResultImageDTOs = async ( data : S [ 'InvocationCompleteEvent' ] ) : Promise < ImageDTO [ ] > => {
90
134
const { result } = data ;
91
- if ( result . type === 'image_output' ) {
92
- return getImageDTOSafe ( result . image . image_name ) ;
135
+ const imageDTOs : ImageDTO [ ] = [ ] ;
136
+ for ( const [ _name , value ] of objectEntries ( result ) ) {
137
+ if ( isImageField ( value ) ) {
138
+ const imageDTO = await getImageDTOSafe ( value . image_name ) ;
139
+ if ( imageDTO ) {
140
+ imageDTOs . push ( imageDTO ) ;
141
+ }
142
+ } else if ( isImageFieldCollection ( value ) ) {
143
+ for ( const imageField of value ) {
144
+ const imageDTO = await getImageDTOSafe ( imageField . image_name ) ;
145
+ if ( imageDTO ) {
146
+ imageDTOs . push ( imageDTO ) ;
147
+ }
148
+ }
149
+ }
93
150
}
94
- return null ;
151
+ return imageDTOs ;
95
152
} ;
96
153
97
154
const handleOriginWorkflows = async ( data : S [ 'InvocationCompleteEvent' ] ) => {
@@ -107,16 +164,15 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
107
164
upsertExecutionState ( nes . nodeId , nes ) ;
108
165
}
109
166
110
- const imageDTO = await getResultImageDTO ( data ) ;
111
-
112
- if ( imageDTO && ! imageDTO . is_intermediate ) {
113
- addImageToGallery ( data , imageDTO ) ;
114
- }
167
+ const imageDTOs = await getResultImageDTOs ( data ) ;
168
+ addImagesToGallery ( data , imageDTOs ) ;
115
169
} ;
116
170
117
171
const handleOriginCanvas = async ( data : S [ 'InvocationCompleteEvent' ] ) => {
118
- const imageDTO = await getResultImageDTO ( data ) ;
172
+ const imageDTOs = await getResultImageDTOs ( data ) ;
119
173
174
+ // We expect only a single image in the canvas output
175
+ const imageDTO = imageDTOs [ 0 ] ;
120
176
if ( ! imageDTO ) {
121
177
return ;
122
178
}
@@ -127,20 +183,17 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
127
183
if ( data . result . type === 'image_output' ) {
128
184
dispatch ( stagingAreaImageStaged ( { stagingAreaImage : { imageDTO, offsetX : 0 , offsetY : 0 } } ) ) ;
129
185
}
130
- addImageToGallery ( data , imageDTO ) ;
186
+ addImagesToGallery ( data , [ imageDTO ] ) ;
131
187
}
132
188
} else if ( ! imageDTO . is_intermediate ) {
133
189
// Desintaion is gallery
134
- addImageToGallery ( data , imageDTO ) ;
190
+ addImagesToGallery ( data , [ imageDTO ] ) ;
135
191
}
136
192
} ;
137
193
138
194
const handleOriginOther = async ( data : S [ 'InvocationCompleteEvent' ] ) => {
139
- const imageDTO = await getResultImageDTO ( data ) ;
140
-
141
- if ( imageDTO && ! imageDTO . is_intermediate ) {
142
- addImageToGallery ( data , imageDTO ) ;
143
- }
195
+ const imageDTOs = await getResultImageDTOs ( data ) ;
196
+ addImagesToGallery ( data , imageDTOs ) ;
144
197
} ;
145
198
146
199
return async ( data : S [ 'InvocationCompleteEvent' ] ) => {
0 commit comments