Skip to content

Commit 28e16b1

Browse files
committed
Triage WriteBackendMethods into "unimplemented" (w/ better errors) vs optimize* (centralized, even if noop).
1 parent c3f51fd commit 28e16b1

File tree

4 files changed

+129
-85
lines changed

4 files changed

+129
-85
lines changed

crates/rustc_codegen_spirv/src/lib.rs

Lines changed: 93 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -167,16 +167,16 @@ use rustc_session::Session;
167167
use rustc_session::config::{self, OutputFilenames, OutputType};
168168
use rustc_span::symbol::Symbol;
169169
use std::any::Any;
170-
use std::fs::{File, create_dir_all};
170+
use std::fs;
171171
use std::io::Cursor;
172172
use std::io::Write;
173173
use std::path::Path;
174174
use std::sync::Arc;
175175
use tracing::{error, warn};
176176

177177
fn dump_mir(tcx: TyCtxt<'_>, mono_items: &[(MonoItem<'_>, MonoItemData)], path: &Path) {
178-
create_dir_all(path.parent().unwrap()).unwrap();
179-
let mut file = File::create(path).unwrap();
178+
fs::create_dir_all(path.parent().unwrap()).unwrap();
179+
let mut file = fs::File::create(path).unwrap();
180180
for &(mono_item, _) in mono_items {
181181
if let MonoItem::Fn(instance) = mono_item
182182
&& matches!(instance.def, InstanceKind::Item(_))
@@ -189,27 +189,6 @@ fn dump_mir(tcx: TyCtxt<'_>, mono_items: &[(MonoItem<'_>, MonoItemData)], path:
189189
}
190190
}
191191

192-
// TODO: Should this store Vec or Module?
193-
struct SpirvModuleBuffer(Vec<u32>);
194-
195-
impl ModuleBufferMethods for SpirvModuleBuffer {
196-
fn data(&self) -> &[u8] {
197-
spirv_tools::binary::from_binary(&self.0)
198-
}
199-
}
200-
201-
// TODO: Should this store Vec or Module?
202-
struct SpirvThinBuffer(Vec<u32>);
203-
204-
impl ThinBufferMethods for SpirvThinBuffer {
205-
fn data(&self) -> &[u8] {
206-
spirv_tools::binary::from_binary(&self.0)
207-
}
208-
fn thin_link_data(&self) -> &[u8] {
209-
&[]
210-
}
211-
}
212-
213192
#[derive(Clone)]
214193
struct SpirvCodegenBackend;
215194

@@ -306,28 +285,76 @@ impl CodegenBackend for SpirvCodegenBackend {
306285
}
307286
}
308287

288+
struct SpirvModuleBuffer(Vec<u32>);
289+
290+
impl SpirvModuleBuffer {
291+
fn as_bytes(&self) -> &[u8] {
292+
spirv_tools::binary::from_binary(&self.0)
293+
}
294+
}
295+
impl ModuleBufferMethods for SpirvModuleBuffer {
296+
fn data(&self) -> &[u8] {
297+
self.as_bytes()
298+
}
299+
}
300+
impl ThinBufferMethods for SpirvModuleBuffer {
301+
fn data(&self) -> &[u8] {
302+
self.as_bytes()
303+
}
304+
fn thin_link_data(&self) -> &[u8] {
305+
&[]
306+
}
307+
}
308+
309+
impl SpirvCodegenBackend {
310+
fn optimize_common(
311+
_cgcx: &CodegenContext<Self>,
312+
_module: &mut ModuleCodegen<<Self as WriteBackendMethods>::Module>,
313+
) -> Result<(), FatalError> {
314+
// FIXME(eddyb) actually run as many optimization passes as possible,
315+
// before ever serializing `.spv` files that will later get linked.
316+
Ok(())
317+
}
318+
}
319+
309320
impl WriteBackendMethods for SpirvCodegenBackend {
310-
type Module = Vec<u32>;
321+
type Module = rspirv::dr::Module;
311322
type TargetMachine = ();
312323
type TargetMachineError = String;
313324
type ModuleBuffer = SpirvModuleBuffer;
314325
type ThinData = ();
315-
type ThinBuffer = SpirvThinBuffer;
326+
type ThinBuffer = SpirvModuleBuffer;
316327

328+
// FIXME(eddyb) reuse the "merge" stage of `crate::linker` for this, or even
329+
// delegate to `run_fat_lto` (although `-Zcombine-cgu` is much more niche).
317330
fn run_link(
318-
_cgcx: &CodegenContext<Self>,
319-
_diag_handler: DiagCtxtHandle<'_>,
331+
cgcx: &CodegenContext<Self>,
332+
diag_handler: DiagCtxtHandle<'_>,
320333
_modules: Vec<ModuleCodegen<Self::Module>>,
321334
) -> Result<ModuleCodegen<Self::Module>, FatalError> {
322-
todo!()
335+
assert!(
336+
cgcx.opts.unstable_opts.combine_cgu,
337+
"`run_link` (for `WorkItemResult::NeedsLink`) should \
338+
only be invoked due to `-Zcombine-cgu`"
339+
);
340+
diag_handler.fatal("Rust-GPU does not support `-Zcombine-cgu`")
323341
}
324342

343+
// FIXME(eddyb) reuse the "merge" stage of `crate::linker` for this, or even
344+
// consider setting `requires_lto = true` in the target specs and moving the
345+
// entirety of `crate::linker` into this stage (lacking diagnostics may be
346+
// an issue - it's surprising `CodegenBackend::link` has `Session` at all).
325347
fn run_fat_lto(
326-
_: &CodegenContext<Self>,
327-
_: Vec<FatLtoInput<Self>>,
328-
_: Vec<(SerializedModule<Self::ModuleBuffer>, WorkProduct)>,
348+
cgcx: &CodegenContext<Self>,
349+
_modules: Vec<FatLtoInput<Self>>,
350+
_cached_modules: Vec<(SerializedModule<Self::ModuleBuffer>, WorkProduct)>,
329351
) -> Result<LtoModuleCodegen<Self>, FatalError> {
330-
todo!()
352+
assert!(
353+
cgcx.lto == rustc_session::config::Lto::Fat,
354+
"`run_fat_lto` (for `WorkItemResult::NeedsFatLto`) should \
355+
only be invoked due to `-Clto` (or equivalent)"
356+
);
357+
unreachable!("Rust-GPU does not support fat LTO")
331358
}
332359

333360
fn run_thin_lto(
@@ -347,35 +374,39 @@ impl WriteBackendMethods for SpirvCodegenBackend {
347374
}
348375

349376
fn optimize(
350-
_: &CodegenContext<Self>,
351-
_: DiagCtxtHandle<'_>,
352-
_: &mut ModuleCodegen<Self::Module>,
353-
_: &ModuleConfig,
377+
cgcx: &CodegenContext<Self>,
378+
_dcx: DiagCtxtHandle<'_>,
379+
module: &mut ModuleCodegen<Self::Module>,
380+
_config: &ModuleConfig,
354381
) -> Result<(), FatalError> {
355-
// TODO: Implement
356-
Ok(())
382+
Self::optimize_common(cgcx, module)
357383
}
358384

359385
fn optimize_thin(
360-
_cgcx: &CodegenContext<Self>,
386+
cgcx: &CodegenContext<Self>,
361387
thin_module: ThinModule<Self>,
362388
) -> Result<ModuleCodegen<Self::Module>, FatalError> {
363-
let module = ModuleCodegen {
364-
module_llvm: spirv_tools::binary::to_binary(thin_module.data())
365-
.unwrap()
366-
.to_vec(),
389+
// FIXME(eddyb) the inefficiency of Module -> [u8] -> Module roundtrips
390+
// comes from upstream and it applies to `rustc_codegen_llvm` as well,
391+
// eventually it should be properly addressed (for `ThinLocal` at least).
392+
let mut module = ModuleCodegen {
393+
module_llvm: link::with_rspirv_loader(|loader| {
394+
rspirv::binary::parse_bytes(thin_module.data(), loader)
395+
})
396+
.unwrap(),
367397
name: thin_module.name().to_string(),
368398
kind: ModuleKind::Regular,
369399
thin_lto_buffer: None,
370400
};
401+
Self::optimize_common(cgcx, &mut module)?;
371402
Ok(module)
372403
}
373404

374405
fn optimize_fat(
375-
_: &CodegenContext<Self>,
376-
_: &mut ModuleCodegen<Self::Module>,
406+
cgcx: &CodegenContext<Self>,
407+
module: &mut ModuleCodegen<Self::Module>,
377408
) -> Result<(), FatalError> {
378-
todo!()
409+
Self::optimize_common(cgcx, module)
379410
}
380411

381412
fn codegen(
@@ -384,20 +415,19 @@ impl WriteBackendMethods for SpirvCodegenBackend {
384415
module: ModuleCodegen<Self::Module>,
385416
_config: &ModuleConfig,
386417
) -> Result<CompiledModule, FatalError> {
418+
let kind = module.kind;
419+
let (name, module_buffer) = Self::serialize_module(module);
420+
387421
let path = cgcx.output_filenames.temp_path_for_cgu(
388422
OutputType::Object,
389-
&module.name,
423+
&name,
390424
cgcx.invocation_temp.as_deref(),
391425
);
392-
// Note: endianness doesn't matter, readers deduce endianness from magic header.
393-
let spirv_module = spirv_tools::binary::from_binary(&module.module_llvm);
394-
File::create(&path)
395-
.unwrap()
396-
.write_all(spirv_module)
397-
.unwrap();
426+
fs::write(&path, module_buffer.as_bytes()).unwrap();
427+
398428
Ok(CompiledModule {
399-
name: module.name,
400-
kind: module.kind,
429+
name,
430+
kind,
401431
object: Some(path),
402432
dwarf_object: None,
403433
bytecode: None,
@@ -411,11 +441,14 @@ impl WriteBackendMethods for SpirvCodegenBackend {
411441
module: ModuleCodegen<Self::Module>,
412442
_want_summary: bool,
413443
) -> (String, Self::ThinBuffer) {
414-
(module.name, SpirvThinBuffer(module.module_llvm))
444+
Self::serialize_module(module)
415445
}
416446

417447
fn serialize_module(module: ModuleCodegen<Self::Module>) -> (String, Self::ModuleBuffer) {
418-
(module.name, SpirvModuleBuffer(module.module_llvm))
448+
(
449+
module.name,
450+
SpirvModuleBuffer(module.module_llvm.assemble()),
451+
)
419452
}
420453

421454
fn autodiff(
@@ -424,7 +457,7 @@ impl WriteBackendMethods for SpirvCodegenBackend {
424457
_diff_fncs: Vec<AutoDiffItem>,
425458
_config: &ModuleConfig,
426459
) -> Result<(), FatalError> {
427-
todo!()
460+
unreachable!("Rust-GPU does not support autodiff")
428461
}
429462
}
430463

@@ -490,12 +523,11 @@ impl ExtraBackendMethods for SpirvCodegenBackend {
490523
} else {
491524
with_no_trimmed_paths!(do_codegen(&mut cx));
492525
}
493-
let spirv_module = cx.finalize_module().assemble();
494526

495527
(
496528
ModuleCodegen {
497529
name: cgu_name.to_string(),
498-
module_llvm: spirv_module,
530+
module_llvm: cx.finalize_module(),
499531
kind: ModuleKind::Regular,
500532
thin_lto_buffer: None,
501533
},

crates/rustc_codegen_spirv/src/link.rs

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
33

44
use crate::codegen_cx::{CodegenArgs, SpirvMetadata};
5-
use crate::{SpirvCodegenBackend, SpirvModuleBuffer, SpirvThinBuffer, linker};
5+
use crate::{SpirvCodegenBackend, SpirvModuleBuffer, linker};
66
use ar::{Archive, GnuBuilder, Header};
77
use rspirv::binary::Assemble;
88
use rspirv::dr::Module;
@@ -548,6 +548,16 @@ fn create_archive(files: &[&Path], metadata: &[u8], out_filename: &Path) {
548548
builder.into_inner().unwrap();
549549
}
550550

551+
// HACK(eddyb) hiding the actual implementation to avoid `rspirv::dr::Loader`
552+
// being hardcoded (as future work may need to customize it for various reasons).
553+
pub fn with_rspirv_loader<E>(
554+
f: impl FnOnce(&mut dyn rspirv::binary::Consumer) -> Result<(), E>,
555+
) -> Result<rspirv::dr::Module, E> {
556+
let mut loader = rspirv::dr::Loader::new();
557+
f(&mut loader)?;
558+
Ok(loader.module())
559+
}
560+
551561
/// This is the actual guts of linking: the rest of the link-related functions are just digging through rustc's
552562
/// shenanigans to collect all the object files we need to link.
553563
fn do_link(
@@ -562,11 +572,8 @@ fn do_link(
562572

563573
let mut modules = Vec::new();
564574
let mut add_module = |file_name: &OsStr, bytes: &[u8]| {
565-
let module = {
566-
let mut loader = rspirv::dr::Loader::new();
567-
rspirv::binary::parse_bytes(bytes, &mut loader).unwrap();
568-
loader.module()
569-
};
575+
let module =
576+
with_rspirv_loader(|loader| rspirv::binary::parse_bytes(bytes, loader)).unwrap();
570577
if let Some(dir) = &cg_args.dump_pre_link {
571578
// FIXME(eddyb) is it a good idea to re-`assemble` the `rspirv::dr`
572579
// module, or should this just save the original bytes?
@@ -625,7 +632,7 @@ fn do_link(
625632
// TODO: WorkProduct impl
626633
pub(crate) fn run_thin(
627634
cgcx: &CodegenContext<SpirvCodegenBackend>,
628-
modules: Vec<(String, SpirvThinBuffer)>,
635+
modules: Vec<(String, SpirvModuleBuffer)>,
629636
cached_modules: Vec<(SerializedModule<SpirvModuleBuffer>, WorkProduct)>,
630637
) -> Result<(Vec<LtoModuleCodegen<SpirvCodegenBackend>>, Vec<WorkProduct>), FatalError> {
631638
if cgcx.opts.cg.linker_plugin_lto.enabled() {

crates/rustc_codegen_spirv/src/linker/mod.rs

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ use crate::codegen_cx::{ModuleOutputType, SpirvMetadata};
2222
use crate::custom_decorations::{CustomDecoration, SrcLocDecoration, ZombieDecoration};
2323
use crate::custom_insts;
2424
use either::Either;
25-
use rspirv::binary::{Assemble, Consumer};
26-
use rspirv::dr::{Block, Loader, Module, ModuleHeader, Operand};
25+
use rspirv::binary::Assemble;
26+
use rspirv::dr::{Block, Module, ModuleHeader, Operand};
2727
use rspirv::spirv::{Op, StorageClass, Word};
2828
use rustc_data_structures::fx::FxHashMap;
2929
use rustc_errors::ErrorGuaranteed;
@@ -255,15 +255,21 @@ pub fn link(
255255
}
256256

257257
// merge the binaries
258-
let mut loader = Loader::new();
259-
260-
for module in inputs {
261-
module.all_inst_iter().for_each(|inst| {
262-
loader.consume_instruction(inst.clone());
263-
});
264-
}
258+
let mut output = crate::link::with_rspirv_loader(|loader| {
259+
for module in inputs {
260+
for inst in module.all_inst_iter() {
261+
use rspirv::binary::ParseAction;
262+
match loader.consume_instruction(inst.clone()) {
263+
ParseAction::Continue => {}
264+
ParseAction::Stop => unreachable!(),
265+
ParseAction::Error(err) => return Err(err),
266+
}
267+
}
268+
}
269+
Ok(())
270+
})
271+
.unwrap();
265272

266-
let mut output = loader.module();
267273
let mut header = ModuleHeader::new(bound + 1);
268274
header.set_version(version.0, version.1);
269275
header.generator = 0x001B_0000;
@@ -583,9 +589,10 @@ pub fn link(
583589
// FIXME(eddyb) dump both SPIR-T and `spv_words` if there's an error here.
584590
output = {
585591
let _timer = sess.timer("parse-spv_words-from-spirt");
586-
let mut loader = Loader::new();
587-
rspirv::binary::parse_words(&spv_words, &mut loader).unwrap();
588-
loader.module()
592+
crate::link::with_rspirv_loader(|loader| {
593+
rspirv::binary::parse_words(&spv_words, loader)
594+
})
595+
.unwrap()
589596
};
590597
}
591598

crates/rustc_codegen_spirv/src/linker/test.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::{LinkResult, link};
2-
use rspirv::dr::{Loader, Module};
2+
use rspirv::dr::Module;
33
use rustc_errors::registry::Registry;
44
use rustc_session::CompilerIO;
55
use rustc_session::config::{Input, OutputFilenames, OutputTypes};
@@ -59,9 +59,7 @@ fn validate(spirv: &[u32]) {
5959
}
6060

6161
fn load(bytes: &[u8]) -> Module {
62-
let mut loader = Loader::new();
63-
rspirv::binary::parse_bytes(bytes, &mut loader).unwrap();
64-
loader.module()
62+
crate::link::with_rspirv_loader(|loader| rspirv::binary::parse_bytes(bytes, loader)).unwrap()
6563
}
6664

6765
// FIXME(eddyb) shouldn't this be named just `link`? (`assemble_spirv` is separate)

0 commit comments

Comments
 (0)