Skip to content

Commit 33a7bfd

Browse files
Merge branch 'cl/data' of https://github.com/FluxML/Flux.jl into cl/data
2 parents 69c1a4c + fbe1572 commit 33a7bfd

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

src/data/Data.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@ import SHA
1616

1717
deprecation_message() = @warn("Flux's datasets are deprecated, please use the package MLDatasets.jl")
1818

19-
deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...)
19+
function deps(path...)
20+
if isnothing(@__DIR__) # sysimages
21+
joinpath("deps", path...)
22+
else
23+
joinpath(@__DIR__, "..", "..", "deps", path...)
24+
end
25+
end
2026

2127
function download_and_verify(url, path, hash)
2228
tmppath = tempname()
@@ -59,4 +65,4 @@ export Housing
5965

6066
#########################################
6167

62-
end#module
68+
end#module

src/data/fashion-mnist.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@ module FashionMNIST
33
using ..MNIST: gzopen, imageheader, rawimage, labelheader, rawlabel
44
using ..Data: download_and_verify, deprecation_message
55

6-
const dir = joinpath(@__DIR__, "../../deps/fashion-mnist")
6+
const dir = if isnothing(@__DIR__)
7+
joinpath("deps", "fashion-mnist")
8+
else
9+
joinpath(@__DIR__, "../../deps/fashion-mnist")
10+
end
711

812
function load()
913
mkpath(dir)
@@ -60,4 +64,4 @@ function labels(set = :train)
6064
[rawlabel(io) for _ = 1:N]
6165
end
6266

63-
end
67+
end

src/data/mnist.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ using ..Data: download_and_verify, deprecation_message
55

66
const Gray = Colors.Gray{Colors.N0f8}
77

8-
const dir = joinpath(@__DIR__, "../../deps/mnist")
8+
const dir = if isnothing(@__DIR__)
9+
joinpath("deps", "mnist")
10+
else
11+
joinpath(@__DIR__, "../../deps/mnist")
12+
end
913

1014
function gzopen(f, file)
1115
open(file) do io
@@ -110,4 +114,4 @@ function labels(set = :train)
110114
[rawlabel(io) for _ = 1:N]
111115
end
112116

113-
end # module
117+
end # module

0 commit comments

Comments
 (0)