@@ -31,23 +31,17 @@ void ReorderWeights_Conv3x3_Direct(const T in_weights[], const int in_channels,
3131 for (int i = 0 ; i < 3 * in_channels; ++i) {
3232 out_weights[out_index++] = in_weights[j * in_channels * 9 + (i % in_channels) * 9 + (i / in_channels)];
3333 }
34- for (int i = 3 * in_channels; i < rounded_triple; ++i) {
35- out_weights[out_index++] = T (0 );
36- }
34+ out_index += rounded_triple - 3 * in_channels;
3735 // line 1
3836 for (int i = 3 * in_channels; i < 6 * in_channels; ++i) {
3937 out_weights[out_index++] = in_weights[j * in_channels * 9 + (i % in_channels) * 9 + (i / in_channels)];
4038 }
41- for (int i = rounded_triple + 3 * in_channels; i < 2 * rounded_triple; ++i) {
42- out_weights[out_index++] = T (0 );
43- }
39+ out_index += 2 * rounded_triple - (rounded_triple + 3 * in_channels);
4440 // line 2
4541 for (int i = 6 * in_channels; i < 9 * in_channels; ++i) {
4642 out_weights[out_index++] = in_weights[j * in_channels * 9 + (i % in_channels) * 9 + (i / in_channels)];
4743 }
48- for (int i = 2 * rounded_triple + 3 * in_channels; i < 3 * rounded_triple; ++i) {
49- out_weights[out_index++] = T (0 );
50- }
44+ out_index += 3 * rounded_triple - (2 * rounded_triple + 3 * in_channels);
5145 }
5246}
5347
@@ -65,75 +59,44 @@ void ReorderWeights_Conv3x3_Direct(const T in_weights[], const int in_channels1,
6559 out_weights[out_index++] =
6660 in_weights[j * (in_channels1 + in_channels2) * 9 + (i % in_channels1) * 9 + (i / in_channels1)];
6761 }
68- for (int i = 3 * in_channels1; i < rounded_triple1; ++i) {
69- out_weights[out_index++] = T (0 );
70- }
62+ out_index += rounded_triple1 - 3 * in_channels1;
7163 // line 1
7264 for (int i = 3 * in_channels1; i < 6 * in_channels1; ++i) {
7365 out_weights[out_index++] =
7466 in_weights[j * (in_channels1 + in_channels2) * 9 + (i % in_channels1) * 9 + (i / in_channels1)];
7567 }
76- for (int i = rounded_triple1 + 3 * in_channels1; i < 2 * rounded_triple1; ++i) {
77- out_weights[out_index++] = T (0 );
78- }
68+ out_index += 2 * rounded_triple1 - (rounded_triple1 + 3 * in_channels1);
7969 // line 2
8070 for (int i = 6 * in_channels1; i < 9 * in_channels1; ++i) {
8171 out_weights[out_index++] =
8272 in_weights[j * (in_channels1 + in_channels2) * 9 + (i % in_channels1) * 9 + (i / in_channels1)];
8373 }
84- for (int i = 2 * rounded_triple1 + 3 * in_channels1; i < 3 * rounded_triple1; ++i) {
85- out_weights[out_index++] = T (0 );
86- }
74+ out_index += 3 * rounded_triple1 - (2 * rounded_triple1 + 3 * in_channels1);
8775 assert (out_index == j * 3 * (rounded_triple1 + rounded_triple2) + 3 * rounded_triple1);
8876 // line 0
8977 for (int i = 0 ; i < 3 * in_channels2; ++i) {
9078 out_weights[out_index++] = in_weights[j * (in_channels1 + in_channels2) * 9 +
9179 (in_channels1 + (i % in_channels2)) * 9 + (i / in_channels2)];
9280 }
93- for (int i = 3 * in_channels2; i < rounded_triple2; ++i) {
94- out_weights[out_index++] = T (0 );
95- }
81+ out_index += rounded_triple2 - 3 * in_channels2;
9682 // line 1
9783 for (int i = 3 * in_channels2; i < 6 * in_channels2; ++i) {
9884 out_weights[out_index++] = in_weights[j * (in_channels1 + in_channels2) * 9 +
9985 (in_channels1 + (i % in_channels2)) * 9 + (i / in_channels2)];
10086 }
101- for (int i = rounded_triple2 + 3 * in_channels2; i < 2 * rounded_triple2; ++i) {
102- out_weights[out_index++] = T (0 );
103- }
87+ out_index += 2 * rounded_triple2 - (rounded_triple2 + 3 * in_channels2);
10488 // line 2
10589 for (int i = 6 * in_channels2; i < 9 * in_channels2; ++i) {
10690 out_weights[out_index++] = in_weights[j * (in_channels1 + in_channels2) * 9 +
10791 (in_channels1 + (i % in_channels2)) * 9 + (i / in_channels2)];
10892 }
109- for (int i = 2 * rounded_triple2 + 3 * in_channels2; i < 3 * rounded_triple2; ++i) {
110- out_weights[out_index++] = T (0 );
111- }
112- }
113- }
114-
115- template <typename T>
116- void ReorderWeights_Conv3x3_1Direct_2GEMM (const T in_weights[], const int in_channels1, const int in_channels2,
117- const int out_channels, T out_weights[]) {
118- for (int j = 0 ; j < out_channels; ++j) {
119- for (int c = 0 ; c < 9 ; ++c) {
120- for (int i = 0 ; i < in_channels1; ++i) {
121- out_weights[j * (in_channels1 + in_channels2) * 9 + c * in_channels1 + i] =
122- in_weights[j * (in_channels1 + in_channels2) * 9 + i * 9 + c];
123- }
124- }
125- for (int c = 0 ; c < 9 ; ++c) {
126- for (int i = 0 ; i < in_channels2; ++i) {
127- out_weights[j * (in_channels1 + in_channels2) * 9 + 9 * in_channels1 + i * 9 + c] =
128- in_weights[j * (in_channels1 + in_channels2) * 9 + (in_channels1 + i) * 9 + c];
129- }
130- }
93+ out_index += 3 * rounded_triple2 - (2 * rounded_triple2 + 3 * in_channels2);
13194 }
13295}
13396} // namespace Ray
13497
135- int Ray::SetupUNetFilter (int w, int h, bool alias_memory, bool round_w, unet_filter_tensors_t &out_tensors ,
136- SmallVector<int , 2 > alias_dependencies[]) {
98+ int Ray::SetupUNetFilter (const int w, const int h, const bool alias_memory, const bool round_w,
99+ unet_filter_tensors_t &out_tensors, SmallVector<int , 2 > alias_dependencies[]) {
137100 struct resource_t {
138101 const char *name;
139102 int resolution_div;
@@ -374,7 +337,8 @@ int Ray::SetupUNetWeights(const int alignment, unet_weight_offsets_t *out_offset
374337 return per_output * out_channels;
375338 };
376339
377- const int input_channels = 9 ; // color(3), base color(3) and normals(3)
340+ const int input_channels = 9 ; // color(3), base color(3) and normals(3)
341+ const int output_channels = 3 ; // hdr color(3)
378342 const int el_align = (256 / sizeof (T));
379343
380344 const int total_count =
@@ -394,7 +358,7 @@ int Ray::SetupUNetWeights(const int alignment, unet_weight_offsets_t *out_offset
394358 round_up_ (dec_conv2b_bias.size (), el_align) + round_up (count2 (64 , input_channels, 64 ), el_align) +
395359 round_up_ (dec_conv1a_bias.size (), el_align) + round_up_ (dec_conv1b_weight.size (), el_align) +
396360 round_up_ (dec_conv1b_bias.size (), el_align) + round_up_ (dec_conv0_weight.size (), el_align) +
397- int (dec_conv0_bias.size ());
361+ round_up_ (dec_conv0_bias.size (), el_align );
398362
399363 if (out_offsets) {
400364 out_offsets->enc_conv0_weight = 0 ;
@@ -432,7 +396,6 @@ int Ray::SetupUNetWeights(const int alignment, unet_weight_offsets_t *out_offset
432396 out_offsets->dec_conv0_weight = out_offsets->dec_conv1b_bias + round_up_ (dec_conv1b_bias.size (), el_align);
433397 out_offsets->dec_conv0_bias = out_offsets->dec_conv0_weight + round_up_ (dec_conv0_weight.size (), el_align);
434398
435- assert (out_offsets->dec_conv0_bias + dec_conv0_bias.size () == total_count);
436399 assert ((out_offsets->enc_conv0_weight % el_align) == 0 && (out_offsets->enc_conv0_bias % el_align) == 0 );
437400 assert ((out_offsets->enc_conv1_weight % el_align) == 0 && (out_offsets->enc_conv1_bias % el_align) == 0 );
438401 assert ((out_offsets->enc_conv2_weight % el_align) == 0 && (out_offsets->enc_conv2_bias % el_align) == 0 );
@@ -449,11 +412,14 @@ int Ray::SetupUNetWeights(const int alignment, unet_weight_offsets_t *out_offset
449412 assert ((out_offsets->dec_conv1a_weight % el_align) == 0 && (out_offsets->dec_conv1a_bias % el_align) == 0 );
450413 assert ((out_offsets->dec_conv1b_weight % el_align) == 0 && (out_offsets->dec_conv1b_bias % el_align) == 0 );
451414 assert ((out_offsets->dec_conv0_weight % el_align) == 0 && (out_offsets->dec_conv0_bias % el_align) == 0 );
415+ assert (out_offsets->dec_conv0_bias + round_up_ (dec_conv0_bias.size (), el_align) == total_count);
452416 }
453417
454418 if (out_weights) {
455- std::vector<T> temp;
419+ // Zero out everything to avoid invalid values
420+ std::fill (out_weights, out_weights + total_count, T (0 ));
456421
422+ std::vector<T> temp;
457423 temp.resize (enc_conv0_weight.size ());
458424 for (int i = 0 ; i < enc_conv0_weight.size (); ++i) {
459425 temp[i] = convert_weight<T>(enc_conv0_weight[i]);
@@ -597,7 +563,8 @@ int Ray::SetupUNetWeights(const int alignment, unet_weight_offsets_t *out_offset
597563 for (int i = 0 ; i < dec_conv0_weight.size (); ++i) {
598564 temp[i] = convert_weight<T>(dec_conv0_weight[i]);
599565 }
600- ReorderWeights_Conv3x3_Direct (temp.data (), 32 , 3 , alignment, &out_weights[out_offsets->dec_conv0_weight ]);
566+ ReorderWeights_Conv3x3_Direct (temp.data (), 32 , output_channels, alignment,
567+ &out_weights[out_offsets->dec_conv0_weight ]);
601568 for (int i = 0 ; i < dec_conv0_bias.size (); ++i) {
602569 out_weights[out_offsets->dec_conv0_bias + i] = convert_weight<T>(dec_conv0_bias[i]);
603570 }
0 commit comments