@@ -116,7 +116,7 @@ class ASFileSystem : public FileSystem {
116116
117117 Status DownloadFolder (
118118 const std::string& container, const std::string& path,
119- const std::string& dest);
119+ const std::string& dest, const bool recursive );
120120
121121 std::shared_ptr<asb::BlobServiceClient> client_;
122122 re2::RE2 as_regex_;
@@ -392,7 +392,7 @@ ASFileSystem::FileExists(const std::string& path, bool* exists)
392392Status
393393ASFileSystem::DownloadFolder (
394394 const std::string& container, const std::string& path,
395- const std::string& dest)
395+ const std::string& dest, const bool recursive )
396396{
397397 auto container_client = client_->GetBlobContainerClient (container);
398398 auto func = [&](const std::vector<asb::Models::BlobItem>& blobs,
@@ -408,17 +408,20 @@ ASFileSystem::DownloadFolder(
408408 " Failed to download file at " + blob_item.Name + " :" + ex.what ());
409409 }
410410 }
411- for (const auto & directory_item : blob_prefixes) {
412- const auto & local_path = JoinPath ({dest, BaseName (directory_item)});
413- int status = mkdir (
414- const_cast <char *>(local_path.c_str ()), S_IRUSR | S_IWUSR | S_IXUSR);
415- if (status == -1 ) {
416- return Status (
417- Status::Code::INTERNAL,
418- " Failed to create local folder: " + local_path +
419- " , errno:" + strerror (errno));
411+ if (recursive) {
412+ for (const auto & directory_item : blob_prefixes) {
413+ const auto & local_path = JoinPath ({dest, BaseName (directory_item)});
414+ int status = mkdir (
415+ const_cast <char *>(local_path.c_str ()), S_IRUSR | S_IWUSR | S_IXUSR);
416+ if (status == -1 && errno != EEXIST) {
417+ return Status (
418+ Status::Code::INTERNAL,
419+ " Failed to create local folder: " + local_path +
420+ " , errno:" + strerror (errno));
421+ }
422+ RETURN_IF_ERROR (
423+ DownloadFolder (container, directory_item, local_path, recursive));
420424 }
421- RETURN_IF_ERROR (DownloadFolder (container, directory_item, local_path));
422425 }
423426 return Status::Success;
424427 };
@@ -445,21 +448,30 @@ ASFileSystem::LocalizePath(
445448 " AS file localization not yet implemented " + path);
446449 }
447450
448- std::string folder_template = " /tmp/folderXXXXXX" ;
449- char * tmp_folder = mkdtemp (const_cast <char *>(folder_template.c_str ()));
450- if (tmp_folder == nullptr ) {
451- return Status (
452- Status::Code::INTERNAL,
453- " Failed to create local temp folder: " + folder_template +
454- " , errno:" + strerror (errno));
451+ // Create a local directory for s3 model store.
452+ // If `mount_dir` or ENV variable are not set,
453+ // creates a temporary directory under `/tmp` with the format: "folderXXXXXX".
454+ // Otherwise, will create a folder under specified directory with the name
455+ // indicated in path (i.e. everything after the last encounter of `/`).
456+ const char * env_mount_dir = std::getenv (" TRITON_AZURE_MOUNT_DIRECTORY" );
457+ std::string tmp_folder;
458+ if (mount_dir.empty () && env_mount_dir == nullptr ) {
459+ RETURN_IF_ERROR (triton::core::MakeTemporaryDirectory (
460+ FileSystemType::LOCAL, &tmp_folder));
461+ } else {
462+ tmp_folder = mount_dir.empty () ? std::string (env_mount_dir) : mount_dir;
463+ tmp_folder =
464+ JoinPath ({tmp_folder, path.substr (path.find_last_of (' /' ) + 1 )});
465+ RETURN_IF_ERROR (triton::core::MakeDirectory (
466+ tmp_folder, true /* recursive*/ , true /* allow_dir_exist*/ ));
455467 }
456- localized->reset (new LocalizedPath (path, tmp_folder));
457468
458- std::string dest (folder_template );
469+ localized-> reset ( new LocalizedPath (path, tmp_folder) );
459470
471+ std::string dest (tmp_folder);
460472 std::string container, blob;
461473 RETURN_IF_ERROR (ParsePath (path, &container, &blob));
462- return DownloadFolder (container, blob, dest);
474+ return DownloadFolder (container, blob, dest, recursive );
463475}
464476
465477Status
0 commit comments