Skip to content

Commit 9595823

Browse files
efaulhaberranocha
andauthored
Add macro trixi_include_changeprecision to make a double precision elixir run with single precision (#35)
* Make `trixi_include` more flexible by allowing a mapping to be passed * Bump version * Update docstring of `trixi_include` * Add `trixi_include_changeprecision` * Export new macro * Fix docs * Update src/trixi_include.jl Co-authored-by: Erik Faulhaber <44124897+efaulhaber@users.noreply.github.com> * Add tests * Reformat code * Fix tests on Windows * Cover missing function * Use `mktemp` for temporary files * Reformat * Update src/trixi_include.jl --------- Co-authored-by: Hendrik Ranocha <ranocha@users.noreply.github.com>
1 parent 882537a commit 9595823

File tree

4 files changed

+173
-61
lines changed

4 files changed

+173
-61
lines changed

Project.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
name = "TrixiBase"
22
uuid = "9a0f1c46-06d5-4909-a5a3-ce25d3fa3284"
33
authors = ["Michael Schlottke-Lakemper <michael@sloede.com>"]
4-
version = "0.1.5-DEV"
4+
version = "0.1.5"
55

66
[deps]
7+
ChangePrecision = "3cb15238-376d-56a3-8042-d33272777c9a"
78
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
89

910
[weakdeps]
@@ -13,6 +14,7 @@ MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
1314
TrixiBaseMPIExt = "MPI"
1415

1516
[compat]
17+
ChangePrecision = "1.1.0"
1618
MPI = "0.20"
1719
TimerOutputs = "0.5.25"
1820
julia = "1.8"

src/TrixiBase.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
module TrixiBase
22

3+
using ChangePrecision: ChangePrecision
34
using TimerOutputs: TimerOutput, TimerOutputs
45

56
include("trixi_include.jl")
67
include("trixi_timeit.jl")
78

8-
export trixi_include
9+
export trixi_include, trixi_include_changeprecision
910
export @trixi_timeit, timer, timeit_debug_enabled,
1011
disable_debug_timings, enable_debug_timings
1112

src/trixi_include.jl

+57-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# of `TrixiBase`. However, users will want to evaluate in the global scope of `Main` or something
44
# similar to manage dependencies on their own.
55
"""
6-
trixi_include([mod::Module=Main,] elixir::AbstractString; kwargs...)
6+
trixi_include([mapexpr::Function=identity,] [mod::Module=Main,] elixir::AbstractString; kwargs...)
77
88
`include` the file `elixir` and evaluate its content in the global scope of module `mod`.
99
You can override specific assignments in `elixir` by supplying keyword arguments.
@@ -16,6 +16,10 @@ into calls to `solve` with it's default value used in the SciML ecosystem
1616
for ODEs, see the "Miscellaneous" section of the
1717
[documentation](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/).
1818
19+
The optional first argument `mapexpr` can be used to transform the included code before
20+
it is evaluated: for each parsed expression `expr` in `elixir`, the `include` function
21+
actually evaluates `mapexpr(expr)`. If it is omitted, `mapexpr` defaults to `identity`.
22+
1923
# Examples
2024
2125
```@example
@@ -30,7 +34,7 @@ julia> redirect_stdout(devnull) do
3034
0.1
3135
```
3236
"""
33-
function trixi_include(mod::Module, elixir::AbstractString; kwargs...)
37+
function trixi_include(mapexpr::Function, mod::Module, elixir::AbstractString; kwargs...)
3438
# Check that all kwargs exist as assignments
3539
code = read(elixir, String)
3640
expr = Meta.parse("begin \n$code \nend")
@@ -45,13 +49,63 @@ function trixi_include(mod::Module, elixir::AbstractString; kwargs...)
4549
if !mpi_isparallel(Val{:MPIExt}())
4650
@info "You just called `trixi_include`. Julia may now compile the code, please be patient."
4751
end
48-
Base.include(ex -> replace_assignments(insert_maxiters(ex); kwargs...), mod, elixir)
52+
Base.include(ex -> mapexpr(replace_assignments(insert_maxiters(ex); kwargs...)),
53+
mod, elixir)
54+
end
55+
56+
function trixi_include(mod::Module, elixir::AbstractString; kwargs...)
57+
trixi_include(identity, mod, elixir; kwargs...)
4958
end
5059

5160
function trixi_include(elixir::AbstractString; kwargs...)
5261
trixi_include(Main, elixir; kwargs...)
5362
end
5463

64+
"""
65+
trixi_include_changeprecision(T, [mod::Module=Main,] elixir::AbstractString; kwargs...)
66+
67+
`include` the elixir `elixir` and evaluate its content in the global scope of module `mod`.
68+
You can override specific assignments in `elixir` by supplying keyword arguments,
69+
similar to [`trixi_include`](@ref).
70+
71+
The only difference to [`trixi_include`](@ref) is that the precision of floating-point
72+
numbers in the included elixir is changed to `T`.
73+
More precisely, the package [ChangePrecision.jl](https://github.com/JuliaMath/ChangePrecision.jl)
74+
is used to convert all `Float64` literals, operations like `/` that produce `Float64` results,
75+
and functions like `ones` that return `Float64` arrays by default, to the desired type `T`.
76+
See the documentation of ChangePrecision.jl for more details.
77+
78+
The purpose of this function is to conveniently run a full simulation with `Float32`,
79+
which is orders of magnitude faster on most GPUs than `Float64`, by just including
80+
the elixir with `trixi_include_changeprecision(Float32, elixir)`.
81+
Many constructors in the Trixi.jl framework are written in a way that changing all floating-point
82+
arguments to `Float32` will change the element type to `Float32` as well.
83+
In TrixiParticles.jl, including an elixir with this macro should be sufficient
84+
to run the full simulation with single precision.
85+
"""
86+
function trixi_include_changeprecision(T, mod::Module, filename::AbstractString; kwargs...)
87+
trixi_include(expr -> ChangePrecision.changeprecision(T, replace_trixi_include(T, expr)),
88+
mod, filename; kwargs...)
89+
end
90+
91+
function trixi_include_changeprecision(T, filename::AbstractString; kwargs...)
92+
trixi_include_changeprecision(T, Main, filename; kwargs...)
93+
end
94+
95+
function replace_trixi_include(T, expr)
96+
expr = TrixiBase.walkexpr(expr) do x
97+
if x isa Expr
98+
if x.head === :call && x.args[1] === :trixi_include
99+
x.args[1] = :trixi_include_changeprecision
100+
insert!(x.args, 2, :($T))
101+
end
102+
end
103+
return x
104+
end
105+
106+
return expr
107+
end
108+
55109
# Insert the keyword argument `maxiters` into calls to `solve` and `Trixi.solve`
56110
# with default value `10^5` if it is not already present.
57111
function insert_maxiters(expr)

test/trixi_include.jl

+111-56
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,27 @@
44
x = 4
55
"""
66

7-
filename = tempname()
8-
try
9-
open(filename, "w") do file
10-
write(file, example)
11-
end
7+
mktemp() do path, io
8+
write(io, example)
9+
close(io)
1210

1311
# Use `@trixi_testset`, which wraps code in a temporary module, and call
1412
# `trixi_include` with `@__MODULE__` in order to isolate this test.
15-
@test_nowarn_mod trixi_include(@__MODULE__, filename)
13+
@test_nowarn_mod trixi_include(@__MODULE__, path)
1614
@test @isdefined x
1715
@test x == 4
1816

19-
@test_nowarn_mod trixi_include(@__MODULE__, filename, x = 7)
17+
@test_nowarn_mod trixi_include(@__MODULE__, path, x = 7)
2018

2119
@test x == 7
2220

2321
# Verify default version (that includes in `Main`)
24-
@test_nowarn_mod trixi_include(filename, x = 11)
22+
@test_nowarn_mod trixi_include(path, x = 11)
2523
@test Main.x == 11
2624

2725
@test_throws "assignment `y` not found in expression" trixi_include(@__MODULE__,
28-
filename,
26+
path,
2927
y = 3)
30-
finally
31-
rm(filename, force = true)
3228
end
3329
end
3430

@@ -40,22 +36,18 @@
4036
x = solve()
4137
"""
4238

43-
filename = tempname()
44-
try
45-
open(filename, "w") do file
46-
write(file, example)
47-
end
39+
mktemp() do path, io
40+
write(io, example)
41+
close(io)
4842

4943
# Use `@trixi_testset`, which wraps code in a temporary module, and call
5044
# `trixi_include` with `@__MODULE__` in order to isolate this test.
5145
@test_throws "no method matching solve(; maxiters" trixi_include(@__MODULE__,
52-
filename)
46+
path)
5347

5448
@test_throws "no method matching solve(; maxiters" trixi_include(@__MODULE__,
55-
filename,
49+
path,
5650
maxiters = 3)
57-
finally
58-
rm(filename, force = true)
5951
end
6052
end
6153

@@ -81,49 +73,112 @@
8173
y = solve(; maxiters=0)
8274
"""
8375

84-
filename1 = tempname()
85-
filename2 = tempname()
86-
filename3 = tempname()
87-
filename4 = tempname()
88-
try
89-
open(filename1, "w") do file
90-
write(file, example1)
91-
end
92-
open(filename2, "w") do file
93-
write(file, example2)
94-
end
95-
open(filename3, "w") do file
96-
write(file, example3)
97-
end
98-
open(filename4, "w") do file
99-
write(file, example4)
76+
mktemp() do path1, io1
77+
write(io1, example1)
78+
close(io1)
79+
80+
mktemp() do path2, io2
81+
write(io2, example2)
82+
close(io2)
83+
84+
mktemp() do path3, io3
85+
write(io3, example3)
86+
close(io3)
87+
88+
mktemp() do path4, io4
89+
write(io4, example4)
90+
close(io4)
91+
92+
# Use `@trixi_testset`, which wraps code in a temporary module,
93+
# and call `Base.include` and `trixi_include` with `@__MODULE__`
94+
# in order to isolate this test.
95+
Base.include(@__MODULE__, path1)
96+
@test_nowarn_mod trixi_include(@__MODULE__, path2)
97+
@test @isdefined x
98+
# This is the default `maxiters` inserted by `trixi_include`
99+
@test x == 10^5
100+
101+
@test_nowarn_mod trixi_include(@__MODULE__, path2, maxiters = 7)
102+
# Test that `maxiters` got overwritten
103+
@test x == 7
104+
105+
# Verify that existing `maxiters` is added exactly once in the
106+
# following cases:
107+
# case 1) `maxiters` is *before* semicolon in included file
108+
@test_nowarn_mod trixi_include(@__MODULE__, path3, maxiters = 11)
109+
@test y == 11
110+
# case 2) `maxiters` is *after* semicolon in included file
111+
@test_nowarn_mod trixi_include(@__MODULE__, path3, maxiters = 14)
112+
@test y == 14
113+
end
114+
end
100115
end
116+
end
117+
end
118+
end
119+
120+
@trixi_testset "`trixi_include_changeprecision`" begin
121+
@trixi_testset "Basic" begin
122+
example = """
123+
x = 4.0
124+
y = zeros(3)
125+
"""
126+
127+
mktemp() do path, io
128+
write(io, example)
129+
close(io)
101130

102131
# Use `@trixi_testset`, which wraps code in a temporary module, and call
103-
# `Base.include` and `trixi_include` with `@__MODULE__` in order to isolate this test.
104-
Base.include(@__MODULE__, filename1)
105-
@test_nowarn_mod trixi_include(@__MODULE__, filename2)
132+
# `trixi_include_changeprecision` with `@__MODULE__` in order to isolate this test.
133+
@test_nowarn_mod trixi_include_changeprecision(Float32, @__MODULE__, path)
106134
@test @isdefined x
107-
# This is the default `maxiters` inserted by `trixi_include`
108-
@test x == 10^5
135+
@test x == 4
136+
@test typeof(x) == Float32
137+
@test @isdefined y
138+
@test eltype(y) == Float32
139+
140+
# Manually overwritten assignments are also changed
141+
@test_nowarn_mod trixi_include_changeprecision(Float32, @__MODULE__, path,
142+
x = 7.0)
109143

110-
@test_nowarn_mod trixi_include(@__MODULE__, filename2,
111-
maxiters = 7)
112-
# Test that `maxiters` got overwritten
113144
@test x == 7
145+
@test typeof(x) == Float32
146+
147+
# Verify default version (that includes in `Main`)
148+
@test_nowarn_mod trixi_include_changeprecision(Float32, path, x = 11.0)
149+
@test Main.x == 11
150+
@test typeof(Main.x) == Float32
151+
end
152+
end
153+
154+
@trixi_testset "Recursive" begin
155+
example1 = """
156+
x = 4.0
157+
y = zeros(3)
158+
"""
114159

115-
# Verify that adding `maxiters` to `maxiters` results in exactly one of them
116-
# case 1) `maxiters` is *before* semicolon in included file
117-
@test_nowarn_mod trixi_include(@__MODULE__, filename3, maxiters = 11)
118-
@test y == 11
119-
# case 2) `maxiters` is *after* semicolon in included file
120-
@test_nowarn_mod trixi_include(@__MODULE__, filename3, maxiters = 14)
121-
@test y == 14
122-
finally
123-
rm(filename1, force = true)
124-
rm(filename2, force = true)
125-
rm(filename3, force = true)
126-
rm(filename4, force = true)
160+
mktemp() do path1, io1
161+
write(io1, example1)
162+
close(io1)
163+
164+
# Use raw string to allow backslashes in Windows paths
165+
example2 = """
166+
trixi_include(@__MODULE__, raw"$path1", x = 7.0)
167+
"""
168+
169+
mktemp() do path2, io2
170+
write(io2, example2)
171+
close(io2)
172+
173+
# Use `@trixi_testset`, which wraps code in a temporary module, and call
174+
# `trixi_include_changeprecision` with `@__MODULE__` in order to isolate this test.
175+
@test_nowarn_mod trixi_include_changeprecision(Float32, @__MODULE__, path2)
176+
@test @isdefined x
177+
@test x == 7
178+
@test typeof(x) == Float32
179+
@test @isdefined y
180+
@test eltype(y) == Float32
181+
end
127182
end
128183
end
129184
end

0 commit comments

Comments
 (0)