@@ -14,38 +14,33 @@ pub(crate) fn create_key(args: CreateKeyArgs) -> anyhow::Result<()> {
14
14
crate :: core:: signing:: create_key ( & args. private_key , & args. public_key )
15
15
}
16
16
17
+ fn get_paths_for ( format : Option < FileType > , file_path : & Path ) -> anyhow:: Result < Vec < PathBuf > > {
18
+ // determine handler
19
+ let handler = crate :: core:: handlers:: handler_for ( format, file_path, Scope :: Signing ) ;
20
+ // get the paths to sign or verify
21
+ if let Ok ( handler) = handler {
22
+ handler. paths_to_sign ( file_path)
23
+ } else {
24
+ Ok ( vec ! [ file_path. to_path_buf( ) ] )
25
+ }
26
+ }
27
+
17
28
fn get_paths_of_interest (
18
29
format : Option < FileType > ,
19
30
file_path : & Path ,
20
31
) -> anyhow:: Result < Vec < PathBuf > > {
21
32
let paths = if file_path. is_file ( ) {
22
33
// single file case
23
- // determine handler
24
- let handler = crate :: core:: handlers:: handler_for ( format, file_path, Scope :: Signing ) ;
25
- // get the paths to sign or verify
26
- if let Ok ( handler) = handler {
27
- handler. paths_to_sign ( file_path) ?
28
- } else {
29
- println ! ( "Warning: Unrecognized file format. Signing this file does not ensure that the model data will be signed in its entirety." ) ;
30
- vec ! [ file_path. to_path_buf( ) ]
31
- }
34
+ get_paths_for ( format, file_path) ?
32
35
} else {
33
36
let mut unique = HashSet :: new ( ) ;
34
37
35
38
// collect all files in the directory
36
- for entry in glob ( file_path. join ( "**/*.* " ) . to_str ( ) . unwrap ( ) ) ? {
39
+ for entry in glob ( file_path. join ( "**/*" ) . to_str ( ) . unwrap ( ) ) ? {
37
40
match entry {
38
41
Ok ( path) => {
39
42
if path. is_file ( ) {
40
- // determine handler
41
- if let Ok ( handler) = crate :: core:: handlers:: handler_for (
42
- format. clone ( ) ,
43
- & path,
44
- Scope :: Signing ,
45
- ) {
46
- // add only if handled
47
- unique. extend ( handler. paths_to_sign ( & path) ?) ;
48
- }
43
+ unique. extend ( get_paths_for ( format. clone ( ) , & path) ?) ;
49
44
}
50
45
}
51
46
Err ( e) => println ! ( "{:?}" , e) ,
@@ -125,3 +120,142 @@ pub(crate) fn verify(args: VerifyArgs) -> anyhow::Result<()> {
125
120
126
121
Ok ( ( ) )
127
122
}
123
+
124
+ #[ cfg( test) ]
125
+ mod tests {
126
+ use super :: * ;
127
+ use std:: fs:: File ;
128
+ use tempfile:: TempDir ;
129
+
130
+ #[ test]
131
+ fn test_get_paths_single_file ( ) -> anyhow:: Result < ( ) > {
132
+ let temp_dir = TempDir :: new ( ) ?;
133
+ let file_path = temp_dir. path ( ) . join ( "model.safetensors" ) ;
134
+ File :: create ( & file_path) ?;
135
+
136
+ let paths = get_paths_of_interest ( None , & file_path) ?;
137
+ assert_eq ! ( paths. len( ) , 1 ) ;
138
+ assert_eq ! ( paths[ 0 ] , file_path) ;
139
+
140
+ Ok ( ( ) )
141
+ }
142
+
143
+ #[ test]
144
+ fn test_get_paths_directory ( ) -> anyhow:: Result < ( ) > {
145
+ let temp_dir = TempDir :: new ( ) ?;
146
+
147
+ // Create multiple files
148
+ File :: create ( temp_dir. path ( ) . join ( "model.safetensors" ) ) ?;
149
+ File :: create ( temp_dir. path ( ) . join ( "model.bin" ) ) ?;
150
+ File :: create ( temp_dir. path ( ) . join ( "other.txt" ) ) ?;
151
+
152
+ let paths = get_paths_of_interest ( None , temp_dir. path ( ) ) ?;
153
+ assert_eq ! ( paths. len( ) , 3 ) ;
154
+
155
+ // Sort paths for consistent comparison
156
+ let mut paths: Vec < String > = paths
157
+ . iter ( )
158
+ . map ( |p| p. file_name ( ) . unwrap ( ) . to_string_lossy ( ) . to_string ( ) )
159
+ . collect ( ) ;
160
+ paths. sort ( ) ;
161
+
162
+ assert_eq ! ( paths, vec![ "model.bin" , "model.safetensors" , "other.txt" ] ) ;
163
+
164
+ Ok ( ( ) )
165
+ }
166
+
167
+ #[ test]
168
+ fn test_get_paths_with_format_override ( ) -> anyhow:: Result < ( ) > {
169
+ let temp_dir = TempDir :: new ( ) ?;
170
+
171
+ // Create files with different extensions
172
+ File :: create ( temp_dir. path ( ) . join ( "model.custom" ) ) ?;
173
+ File :: create ( temp_dir. path ( ) . join ( "model.safetensors" ) ) ?;
174
+
175
+ let paths = get_paths_of_interest (
176
+ Some ( FileType :: SafeTensors ) ,
177
+ & temp_dir. path ( ) . join ( "model.custom" ) ,
178
+ ) ?;
179
+ assert_eq ! ( paths. len( ) , 1 ) ;
180
+ assert ! ( paths[ 0 ] . to_string_lossy( ) . ends_with( "model.custom" ) ) ;
181
+
182
+ Ok ( ( ) )
183
+ }
184
+
185
+ #[ test]
186
+ fn test_get_paths_sharded_files ( ) -> anyhow:: Result < ( ) > {
187
+ let temp_dir = TempDir :: new ( ) ?;
188
+
189
+ // Create sharded files
190
+ File :: create ( temp_dir. path ( ) . join ( "model-00001-of-00002.safetensors" ) ) ?;
191
+ File :: create ( temp_dir. path ( ) . join ( "model-00002-of-00002.safetensors" ) ) ?;
192
+ File :: create ( temp_dir. path ( ) . join ( "other.txt" ) ) ?;
193
+
194
+ let paths = get_paths_of_interest ( None , temp_dir. path ( ) ) ?;
195
+ assert_eq ! ( paths. len( ) , 3 ) ;
196
+
197
+ let mut paths: Vec < String > = paths
198
+ . iter ( )
199
+ . map ( |p| p. file_name ( ) . unwrap ( ) . to_string_lossy ( ) . to_string ( ) )
200
+ . collect ( ) ;
201
+ paths. sort ( ) ;
202
+
203
+ assert_eq ! (
204
+ paths,
205
+ vec![
206
+ "model-00001-of-00002.safetensors" ,
207
+ "model-00002-of-00002.safetensors" ,
208
+ "other.txt"
209
+ ]
210
+ ) ;
211
+
212
+ Ok ( ( ) )
213
+ }
214
+
215
+ #[ test]
216
+ fn test_get_paths_nonexistent ( ) {
217
+ let result = get_paths_of_interest ( None , & PathBuf :: from ( "/nonexistent/path" ) ) ;
218
+ assert ! ( result. is_err( ) ) ;
219
+ }
220
+
221
+ #[ test]
222
+ fn test_get_paths_nested_and_hidden ( ) -> anyhow:: Result < ( ) > {
223
+ let temp_dir = TempDir :: new ( ) ?;
224
+
225
+ // Create nested directory structure
226
+ let nested_dir = temp_dir. path ( ) . join ( "nested" ) ;
227
+ let deep_dir = nested_dir. join ( "deep" ) ;
228
+ std:: fs:: create_dir_all ( & deep_dir) ?;
229
+
230
+ // Create various files including hidden ones
231
+ File :: create ( temp_dir. path ( ) . join ( ".hidden" ) ) ?;
232
+ File :: create ( temp_dir. path ( ) . join ( "regular.txt" ) ) ?;
233
+ File :: create ( nested_dir. join ( ".hidden_nested" ) ) ?;
234
+ File :: create ( nested_dir. join ( "nested.bin" ) ) ?;
235
+ File :: create ( deep_dir. join ( ".very_hidden" ) ) ?;
236
+ File :: create ( deep_dir. join ( "deep.dat" ) ) ?;
237
+
238
+ let paths = get_paths_of_interest ( None , temp_dir. path ( ) ) ?;
239
+ assert_eq ! ( paths. len( ) , 6 ) ; // Should find all 6 files
240
+
241
+ let mut paths: Vec < String > = paths
242
+ . iter ( )
243
+ . map ( |p| p. file_name ( ) . unwrap ( ) . to_string_lossy ( ) . to_string ( ) )
244
+ . collect ( ) ;
245
+ paths. sort ( ) ;
246
+
247
+ assert_eq ! (
248
+ paths,
249
+ vec![
250
+ ".hidden" ,
251
+ ".hidden_nested" ,
252
+ ".very_hidden" ,
253
+ "deep.dat" ,
254
+ "nested.bin" ,
255
+ "regular.txt"
256
+ ]
257
+ ) ;
258
+
259
+ Ok ( ( ) )
260
+ }
261
+ }
0 commit comments