Skip to content

Commit b2079f2

Browse files
committed
Fix out-of-bounds (OOB) access issue
(detected with AddressSanitizer, ASAN)
1 parent d7b5d2b commit b2079f2

File tree

3 files changed

+27
-65
lines changed

3 files changed

+27
-65
lines changed

internal/TextureUtilsSSE2.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
#if defined(_M_IX86) || defined(_M_X64) || defined(__i386__) || defined(__x86_64__)
22
#include "TextureUtils.h"
33

4-
#ifdef __GNUC__
5-
#pragma GCC push_options
6-
#pragma GCC target("sse2")
7-
#endif
8-
94
#include <cstring>
105

116
#include <immintrin.h>
@@ -36,7 +31,7 @@ static const __m128i RGB_to_RGBA = _mm_set_epi8(-1 /* Insert zero */, 11, 10, 9,
3631

3732
#ifdef __GNUC__
3833
#pragma GCC push_options
39-
#pragma GCC target("avx")
34+
#pragma GCC target("ssse3")
4035
#endif
4136
#ifdef __clang__
4237
#pragma clang attribute push(__attribute__((target("ssse3"))), apply_to = function)
@@ -612,8 +607,4 @@ void EmitAlphaOnlyIndices_SSE2(const uint8_t block[16], const uint8_t min_alpha,
612607

613608
#undef _ABS
614609

615-
#ifdef __GNUC__
616-
#pragma GCC pop_options
617-
#endif
618-
619610
#endif // defined(_M_IX86) || defined(_M_X64) || defined(__i386__) || defined(__x86_64__)

internal/UNetFilter.cpp

Lines changed: 20 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -31,23 +31,17 @@ void ReorderWeights_Conv3x3_Direct(const T in_weights[], const int in_channels,
3131
for (int i = 0; i < 3 * in_channels; ++i) {
3232
out_weights[out_index++] = in_weights[j * in_channels * 9 + (i % in_channels) * 9 + (i / in_channels)];
3333
}
34-
for (int i = 3 * in_channels; i < rounded_triple; ++i) {
35-
out_weights[out_index++] = T(0);
36-
}
34+
out_index += rounded_triple - 3 * in_channels;
3735
// line 1
3836
for (int i = 3 * in_channels; i < 6 * in_channels; ++i) {
3937
out_weights[out_index++] = in_weights[j * in_channels * 9 + (i % in_channels) * 9 + (i / in_channels)];
4038
}
41-
for (int i = rounded_triple + 3 * in_channels; i < 2 * rounded_triple; ++i) {
42-
out_weights[out_index++] = T(0);
43-
}
39+
out_index += 2 * rounded_triple - (rounded_triple + 3 * in_channels);
4440
// line 2
4541
for (int i = 6 * in_channels; i < 9 * in_channels; ++i) {
4642
out_weights[out_index++] = in_weights[j * in_channels * 9 + (i % in_channels) * 9 + (i / in_channels)];
4743
}
48-
for (int i = 2 * rounded_triple + 3 * in_channels; i < 3 * rounded_triple; ++i) {
49-
out_weights[out_index++] = T(0);
50-
}
44+
out_index += 3 * rounded_triple - (2 * rounded_triple + 3 * in_channels);
5145
}
5246
}
5347

@@ -65,75 +59,44 @@ void ReorderWeights_Conv3x3_Direct(const T in_weights[], const int in_channels1,
6559
out_weights[out_index++] =
6660
in_weights[j * (in_channels1 + in_channels2) * 9 + (i % in_channels1) * 9 + (i / in_channels1)];
6761
}
68-
for (int i = 3 * in_channels1; i < rounded_triple1; ++i) {
69-
out_weights[out_index++] = T(0);
70-
}
62+
out_index += rounded_triple1 - 3 * in_channels1;
7163
// line 1
7264
for (int i = 3 * in_channels1; i < 6 * in_channels1; ++i) {
7365
out_weights[out_index++] =
7466
in_weights[j * (in_channels1 + in_channels2) * 9 + (i % in_channels1) * 9 + (i / in_channels1)];
7567
}
76-
for (int i = rounded_triple1 + 3 * in_channels1; i < 2 * rounded_triple1; ++i) {
77-
out_weights[out_index++] = T(0);
78-
}
68+
out_index += 2 * rounded_triple1 - (rounded_triple1 + 3 * in_channels1);
7969
// line 2
8070
for (int i = 6 * in_channels1; i < 9 * in_channels1; ++i) {
8171
out_weights[out_index++] =
8272
in_weights[j * (in_channels1 + in_channels2) * 9 + (i % in_channels1) * 9 + (i / in_channels1)];
8373
}
84-
for (int i = 2 * rounded_triple1 + 3 * in_channels1; i < 3 * rounded_triple1; ++i) {
85-
out_weights[out_index++] = T(0);
86-
}
74+
out_index += 3 * rounded_triple1 - (2 * rounded_triple1 + 3 * in_channels1);
8775
assert(out_index == j * 3 * (rounded_triple1 + rounded_triple2) + 3 * rounded_triple1);
8876
// line 0
8977
for (int i = 0; i < 3 * in_channels2; ++i) {
9078
out_weights[out_index++] = in_weights[j * (in_channels1 + in_channels2) * 9 +
9179
(in_channels1 + (i % in_channels2)) * 9 + (i / in_channels2)];
9280
}
93-
for (int i = 3 * in_channels2; i < rounded_triple2; ++i) {
94-
out_weights[out_index++] = T(0);
95-
}
81+
out_index += rounded_triple2 - 3 * in_channels2;
9682
// line 1
9783
for (int i = 3 * in_channels2; i < 6 * in_channels2; ++i) {
9884
out_weights[out_index++] = in_weights[j * (in_channels1 + in_channels2) * 9 +
9985
(in_channels1 + (i % in_channels2)) * 9 + (i / in_channels2)];
10086
}
101-
for (int i = rounded_triple2 + 3 * in_channels2; i < 2 * rounded_triple2; ++i) {
102-
out_weights[out_index++] = T(0);
103-
}
87+
out_index += 2 * rounded_triple2 - (rounded_triple2 + 3 * in_channels2);
10488
// line 2
10589
for (int i = 6 * in_channels2; i < 9 * in_channels2; ++i) {
10690
out_weights[out_index++] = in_weights[j * (in_channels1 + in_channels2) * 9 +
10791
(in_channels1 + (i % in_channels2)) * 9 + (i / in_channels2)];
10892
}
109-
for (int i = 2 * rounded_triple2 + 3 * in_channels2; i < 3 * rounded_triple2; ++i) {
110-
out_weights[out_index++] = T(0);
111-
}
112-
}
113-
}
114-
115-
template <typename T>
116-
void ReorderWeights_Conv3x3_1Direct_2GEMM(const T in_weights[], const int in_channels1, const int in_channels2,
117-
const int out_channels, T out_weights[]) {
118-
for (int j = 0; j < out_channels; ++j) {
119-
for (int c = 0; c < 9; ++c) {
120-
for (int i = 0; i < in_channels1; ++i) {
121-
out_weights[j * (in_channels1 + in_channels2) * 9 + c * in_channels1 + i] =
122-
in_weights[j * (in_channels1 + in_channels2) * 9 + i * 9 + c];
123-
}
124-
}
125-
for (int c = 0; c < 9; ++c) {
126-
for (int i = 0; i < in_channels2; ++i) {
127-
out_weights[j * (in_channels1 + in_channels2) * 9 + 9 * in_channels1 + i * 9 + c] =
128-
in_weights[j * (in_channels1 + in_channels2) * 9 + (in_channels1 + i) * 9 + c];
129-
}
130-
}
93+
out_index += 3 * rounded_triple2 - (2 * rounded_triple2 + 3 * in_channels2);
13194
}
13295
}
13396
} // namespace Ray
13497

135-
int Ray::SetupUNetFilter(int w, int h, bool alias_memory, bool round_w, unet_filter_tensors_t &out_tensors,
136-
SmallVector<int, 2> alias_dependencies[]) {
98+
int Ray::SetupUNetFilter(const int w, const int h, const bool alias_memory, const bool round_w,
99+
unet_filter_tensors_t &out_tensors, SmallVector<int, 2> alias_dependencies[]) {
137100
struct resource_t {
138101
const char *name;
139102
int resolution_div;
@@ -374,7 +337,8 @@ int Ray::SetupUNetWeights(const int alignment, unet_weight_offsets_t *out_offset
374337
return per_output * out_channels;
375338
};
376339

377-
const int input_channels = 9; // color(3), base color(3) and normals(3)
340+
const int input_channels = 9; // color(3), base color(3) and normals(3)
341+
const int output_channels = 3; // hdr color(3)
378342
const int el_align = (256 / sizeof(T));
379343

380344
const int total_count =
@@ -394,7 +358,7 @@ int Ray::SetupUNetWeights(const int alignment, unet_weight_offsets_t *out_offset
394358
round_up_(dec_conv2b_bias.size(), el_align) + round_up(count2(64, input_channels, 64), el_align) +
395359
round_up_(dec_conv1a_bias.size(), el_align) + round_up_(dec_conv1b_weight.size(), el_align) +
396360
round_up_(dec_conv1b_bias.size(), el_align) + round_up_(dec_conv0_weight.size(), el_align) +
397-
int(dec_conv0_bias.size());
361+
round_up_(dec_conv0_bias.size(), el_align);
398362

399363
if (out_offsets) {
400364
out_offsets->enc_conv0_weight = 0;
@@ -432,7 +396,6 @@ int Ray::SetupUNetWeights(const int alignment, unet_weight_offsets_t *out_offset
432396
out_offsets->dec_conv0_weight = out_offsets->dec_conv1b_bias + round_up_(dec_conv1b_bias.size(), el_align);
433397
out_offsets->dec_conv0_bias = out_offsets->dec_conv0_weight + round_up_(dec_conv0_weight.size(), el_align);
434398

435-
assert(out_offsets->dec_conv0_bias + dec_conv0_bias.size() == total_count);
436399
assert((out_offsets->enc_conv0_weight % el_align) == 0 && (out_offsets->enc_conv0_bias % el_align) == 0);
437400
assert((out_offsets->enc_conv1_weight % el_align) == 0 && (out_offsets->enc_conv1_bias % el_align) == 0);
438401
assert((out_offsets->enc_conv2_weight % el_align) == 0 && (out_offsets->enc_conv2_bias % el_align) == 0);
@@ -449,11 +412,14 @@ int Ray::SetupUNetWeights(const int alignment, unet_weight_offsets_t *out_offset
449412
assert((out_offsets->dec_conv1a_weight % el_align) == 0 && (out_offsets->dec_conv1a_bias % el_align) == 0);
450413
assert((out_offsets->dec_conv1b_weight % el_align) == 0 && (out_offsets->dec_conv1b_bias % el_align) == 0);
451414
assert((out_offsets->dec_conv0_weight % el_align) == 0 && (out_offsets->dec_conv0_bias % el_align) == 0);
415+
assert(out_offsets->dec_conv0_bias + round_up_(dec_conv0_bias.size(), el_align) == total_count);
452416
}
453417

454418
if (out_weights) {
455-
std::vector<T> temp;
419+
// Zero out everything to avoid invalid values
420+
std::fill(out_weights, out_weights + total_count, T(0));
456421

422+
std::vector<T> temp;
457423
temp.resize(enc_conv0_weight.size());
458424
for (int i = 0; i < enc_conv0_weight.size(); ++i) {
459425
temp[i] = convert_weight<T>(enc_conv0_weight[i]);
@@ -597,7 +563,8 @@ int Ray::SetupUNetWeights(const int alignment, unet_weight_offsets_t *out_offset
597563
for (int i = 0; i < dec_conv0_weight.size(); ++i) {
598564
temp[i] = convert_weight<T>(dec_conv0_weight[i]);
599565
}
600-
ReorderWeights_Conv3x3_Direct(temp.data(), 32, 3, alignment, &out_weights[out_offsets->dec_conv0_weight]);
566+
ReorderWeights_Conv3x3_Direct(temp.data(), 32, output_channels, alignment,
567+
&out_weights[out_offsets->dec_conv0_weight]);
601568
for (int i = 0; i < dec_conv0_bias.size(); ++i) {
602569
out_weights[out_offsets->dec_conv0_bias + i] = convert_weight<T>(dec_conv0_bias[i]);
603570
}

scripts/analyze_output.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ def process_file(f, all_tests):
3333
else:
3434
test['required_samples'] = max(test['required_samples'], required_samples)
3535

36+
def test_failed(key, value):
37+
if 'psnr_tested' in value:
38+
return value['psnr_tested'] < value['psnr_threshold'] or (value['psnr_tested'] - value['psnr_threshold']) > 0.5 or value['fireflies_tested'] > value['fireflies_threshold'] or (value['fireflies_tested'] - value['fireflies_threshold']) > 100
39+
return True
40+
3641
def main():
3742
all_tests = {}
3843

@@ -49,8 +54,7 @@ def main():
4954
print("Failed to process ", sys.argv[i])
5055

5156
# Print failed tests first
52-
condition = lambda key, value: value['psnr_tested'] < value['psnr_threshold'] or (value['psnr_tested'] - value['psnr_threshold']) > 0.5 or value['fireflies_tested'] > value['fireflies_threshold'] or (value['fireflies_tested'] - value['fireflies_threshold']) > 100
53-
failed_tests = {key: value for key, value in all_tests.items() if condition(key, value)}
57+
failed_tests = {key: value for key, value in all_tests.items() if test_failed(key, value)}
5458

5559
print("-- Failed tests --")
5660
print(json.dumps(failed_tests, indent=4))

0 commit comments

Comments
 (0)