Skip to content

Commit bfc93e2

Browse files
authored
In-memory export of VCF info flags (#613)
1 parent bb322b3 commit bfc93e2

File tree

64 files changed

+152
-5
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+152
-5
lines changed

apis/python/tests/test_tiledbvcf.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1503,3 +1503,24 @@ def test_gvcf_export(tmp_path):
15031503
for test in tests:
15041504
df = ds.read(attrs=attrs, regions=test["region"], set_af_filter="<=1.0")
15051505
assert set(df["sample_name"].unique()) == set(test["samples"])
1506+
1507+
1508+
def test_flag_export(tmp_path):
1509+
# Create the dataset
1510+
uri = os.path.join(tmp_path, "dataset")
1511+
ds = tiledbvcf.Dataset(uri, mode="w")
1512+
samples = [os.path.join(TESTS_INPUT_DIR, s) for s in ["small.vcf.gz"]]
1513+
ds.create_dataset()
1514+
ds.ingest_samples(samples)
1515+
1516+
# Read info flags
1517+
ds = tiledbvcf.Dataset(uri, mode="r")
1518+
df = ds.read(attrs=["pos_start", "info_DB", "info_DS"])
1519+
df = df.sort_values(by=["pos_start"])
1520+
1521+
# Check if flags match the expected values
1522+
expected_db = [1, 1, 1, 0, 0, 1]
1523+
assert df["info_DB"].tolist() == expected_db
1524+
1525+
expected_ds = [1, 1, 0, 0, 1, 1]
1526+
assert df["info_DS"].tolist() == expected_ds

apis/spark/src/test/java/io/tiledb/vcf/VCFDatasourceTest.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,4 +873,36 @@ public void testTimeTravel() throws java.text.ParseException {
873873
Assert.assertEquals(rows.get(1).getString(0), "HG01762");
874874
Assert.assertEquals(rows.get(2).getString(0), "HG01762");
875875
}
876+
877+
@Test
878+
public void testInfoFlagExport() {
879+
Dataset<Row> dfRead =
880+
session()
881+
.read()
882+
.format("io.tiledb.vcf")
883+
.option("uri", testSampleGroupURI("small.tdb", "v4"))
884+
.option("ranges", "1:10000-20000")
885+
.option("tiledb.vfs.num_threads", 1)
886+
.option("tiledb.vcf.log_level", "DEBUG")
887+
.load();
888+
Dataset<Row> df = dfRead.select("contig", "posStart", "info_DB", "info_DS");
889+
df.show();
890+
891+
// Define expected values
892+
List<Integer> expectedDB = Arrays.asList(1, 1, 1, 0, 0, 1);
893+
List<Integer> expectedDS = Arrays.asList(1, 1, 0, 0, 1, 1);
894+
895+
// Get actual values
896+
List<Row> rows = df.collectAsList();
897+
List<Integer> actualDB = new ArrayList<>();
898+
List<Integer> actualDS = new ArrayList<>();
899+
for (Row row : rows) {
900+
actualDB.add(row.getInt(2));
901+
actualDS.add(row.getInt(3));
902+
}
903+
904+
// Compare
905+
Assert.assertEquals(expectedDB, actualDB);
906+
Assert.assertEquals(expectedDS, actualDS);
907+
}
876908
}

apis/spark3/src/test/java/io/tiledb/vcf/VCFDatasourceTest.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,4 +873,36 @@ public void testTimeTravel() throws java.text.ParseException {
873873
Assert.assertEquals(rows.get(1).getString(0), "HG01762");
874874
Assert.assertEquals(rows.get(2).getString(0), "HG01762");
875875
}
876+
877+
@Test
878+
public void testInfoFlagExport() {
879+
Dataset<Row> dfRead =
880+
session()
881+
.read()
882+
.format("io.tiledb.vcf")
883+
.option("uri", testSampleGroupURI("small.tdb", "v4"))
884+
.option("ranges", "1:10000-20000")
885+
.option("tiledb.vfs.num_threads", 1)
886+
.option("tiledb.vcf.log_level", "DEBUG")
887+
.load();
888+
Dataset<Row> df = dfRead.select("contig", "posStart", "info_DB", "info_DS");
889+
df.show();
890+
891+
// Define expected values
892+
List<Integer> expectedDB = Arrays.asList(1, 1, 1, 0, 0, 1);
893+
List<Integer> expectedDS = Arrays.asList(1, 1, 0, 0, 1, 1);
894+
895+
// Get actual values
896+
List<Row> rows = df.collectAsList();
897+
List<Integer> actualDB = new ArrayList<>();
898+
List<Integer> actualDS = new ArrayList<>();
899+
for (Row row : rows) {
900+
actualDB.add(row.getInt(2));
901+
actualDS.add(row.getInt(3));
902+
}
903+
904+
// Compare
905+
Assert.assertEquals(expectedDB, actualDB);
906+
Assert.assertEquals(expectedDS, actualDS);
907+
}
876908
}

libtiledbvcf/src/c_api/tiledbvcf_enum.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ TILEDB_VCF_READ_STATUS_ENUM(FAILED) = 0,
4242
TILEDB_VCF_ATTR_DATATYPE_ENUM(INT32) = 2,
4343
/** 32-bit floating-point */
4444
TILEDB_VCF_ATTR_DATATYPE_ENUM(FLOAT32) = 3,
45+
/** Info Flag */
46+
TILEDB_VCF_ATTR_DATATYPE_ENUM(FLAG) = 4,
4547
#endif
4648

4749
#ifdef TILEDB_VCF_CHECKSUM_TYPE_ENUM

libtiledbvcf/src/read/in_memory_exporter.cc

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,27 @@
3030
namespace tiledb {
3131
namespace vcf {
3232

33+
// Declare static variables
34+
std::set<std::string> InMemoryExporter::info_flags_;
35+
36+
void InMemoryExporter::add_flag(const std::string& field) {
37+
// Add info_ to field if it is missing
38+
std::string info_field = field;
39+
if (field.substr(0, 5) != "info_") {
40+
info_field = "info_" + field;
41+
}
42+
info_flags_.insert(info_field);
43+
}
44+
45+
bool InMemoryExporter::is_flag(const std::string& field) {
46+
// Add info_ to field if it is missing
47+
std::string info_field = field;
48+
if (field.substr(0, 5) != "info_") {
49+
info_field = "info_" + field;
50+
}
51+
return info_flags_.count(info_field) > 0;
52+
}
53+
3354
void InMemoryExporter::set_buffer_values(
3455
const std::string& attribute, void* buff, int64_t buff_size) {
3556
if (buff == nullptr) {
@@ -644,6 +665,18 @@ void InMemoryExporter::attribute_datatype(
644665
break;
645666
}
646667

668+
if (*datatype == AttrDatatype::FLAG) {
669+
// Add this attribute to the set of info flags.
670+
add_flag(attribute);
671+
672+
// Set buffer attributes for flags and use INT32 for the datatype.
673+
*var_len = false;
674+
*nullable = false;
675+
*list = false;
676+
*datatype = AttrDatatype::INT32;
677+
return;
678+
}
679+
647680
*var_len = !fixed_len_attr(attribute);
648681
*nullable = nullable_attr(attribute);
649682
*list = var_len_list_attr(attribute);
@@ -672,7 +705,7 @@ AttrDatatype InMemoryExporter::get_info_fmt_datatype(
672705
dataset->fmt_field_type(field_name, hdr);
673706
switch (htslib_type) {
674707
case BCF_HT_FLAG:
675-
return AttrDatatype::INT32;
708+
return AttrDatatype::FLAG;
676709
case BCF_HT_STR:
677710
return AttrDatatype::CHAR;
678711
case BCF_HT_INT:
@@ -964,6 +997,10 @@ bool InMemoryExporter::copy_info_fmt_value(
964997
for (unsigned i = 0; i < nelts; i++)
965998
decoded[i] = bcf_gt_allele(genotype[i]);
966999
return copy_cell(dest, decoded, nelts * sizeof(int), nelts, hdr);
1000+
} else if (is_flag(field_name)) {
1001+
// If this is a flag, convert the src pointer to a true or false value.
1002+
int flag = src == nullptr ? 0 : 1;
1003+
return copy_cell(dest, &flag, sizeof(int), 1, hdr);
9671004
} else {
9681005
return copy_cell(dest, src, nbytes, nelts, hdr);
9691006
}

libtiledbvcf/src/read/in_memory_exporter.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,23 @@ class InMemoryExporter : public Exporter {
255255
/** Reusable string buffer for temp results. */
256256
std::string str_buff_;
257257

258+
/**
259+
* Set of info field names with type flag.
260+
*
261+
* NOTE: Only info fields are allowed to have type flag.
262+
*/
263+
static std::set<std::string> info_flags_;
264+
258265
/* ********************************* */
259266
/* PRIVATE METHODS */
260267
/* ********************************* */
261268

269+
/** Adds a field to the set of info flags. */
270+
static void add_flag(const std::string& field);
271+
272+
/** Returns true if the given field is an info flag. */
273+
static bool is_flag(const std::string& field);
274+
262275
/**
263276
* Returns the ExportableAttribute corresponding to the given attribute name.
264277
*/

libtiledbvcf/src/write/writer_worker_v4.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,13 @@ void WriterWorkerV4::buffer_info_field(
464464
buff->append(&type, sizeof(int));
465465
buff->append(&num_vals, sizeof(int));
466466
if (val->dst) {
467-
buff->append(val->dst, num_vals * utils::bcf_type_size(type));
467+
if (type == BCF_HT_FLAG) {
468+
// Write a dummy value for flags.
469+
int flag = 1;
470+
buff->append(&flag, num_vals * utils::bcf_type_size(type));
471+
} else {
472+
buff->append(val->dst, num_vals * utils::bcf_type_size(type));
473+
}
468474
} else {
469475
// val->dst can be NULL if the only INFO value is a flag
470476
assert(num_vals == 1);

0 commit comments

Comments
 (0)