Skip to content
327 changes: 326 additions & 1 deletion src/array_ops.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1260,6 +1260,291 @@ void detrend(bp::object & tod, const std::string & method, const int linear_ncou
}
}

// Locate, by bisection, the bin of ``bin_edges`` that contains ``value``.
//
// ``bin_edges`` is read at indices 0..nbins, i.e. it must hold nbins + 1
// entries.  Bin k covers the half-open interval
// [bin_edges[k], bin_edges[k + 1]).
//
// Out-of-range values are clamped (mimicking np.clip): anything below the
// first edge maps to bin 0 and anything at or above the last edge maps to
// bin nbins - 1.
//
// Returns the bin index, or -1 if the search finishes without any bin
// passing the interval test.  A NaN ``value`` ends up here because every
// ordered comparison against NaN is false.
template <typename T>
int _find_bin_index(const T* bin_edges, T value, int nbins) {
    // Clamp values that fall off either end onto the outermost bins.
    if (value < bin_edges[0])
        return 0;
    if (value >= bin_edges[nbins])
        return nbins - 1;

    int lo = 0;
    int hi = nbins;
    while (lo < hi) {
        const int probe = lo + (hi - lo) / 2;

        if (value >= bin_edges[probe + 1]) {
            // The value sits at or beyond the probe's upper edge: go right.
            lo = probe + 1;
            continue;
        }
        if (value >= bin_edges[probe] && value < bin_edges[probe + 1]) {
            return probe;
        }
        // Otherwise the value is left of the probe (or unordered).
        hi = probe;
    }

    return -1;
}

// Bin a 2-d ``signal`` (ndets, nsamps) along its last axis according to the
// 1-d ``bin_by`` array, writing per-detector results into the caller-provided
// output arrays.
//
// Inputs (all buffers of element type T unless noted):
//   bin_by       (nsamps)            value that decides each sample's bin
//   signal       (ndets, nsamps)     data to bin
//   weight       (nsamps) or (ndets, nsamps)
//   bin_edges    (nbins + 1)         edges handed to _find_bin_index
//   lower, upper                     samples with bin_by < lower or
//                                    bin_by > upper are dropped entirely
//   flags_data   optional int array, 1-d (nsamps) or 2-d (ndets, nsamps) as
//                described by ``flags_shape``; non-zero entries are skipped.
//                ``flags_stride`` is the row stride in elements (0 for 1-d).
//
// Outputs (overwritten in place):
//   binned_sig        (ndets, nbins)  sum(w * s) / bin_counts
//   binned_sig_sigma  (ndets, nbins)  sqrt(|sum((w*s)^2)/bin_counts
//                                     - binned_sig^2|) / sqrt(bin_counts)
//   bin_counts        (ndets, nbins)  int; accumulated weights.  Because the
//                                     accumulator is an int, each addition of
//                                     a floating-point weight truncates.
// Bins whose count is not positive get NaN in both signal outputs.
//
// The per-detector loop is parallelised with OpenMP.
//
// Throws ValueError_exception on contiguity or shape violations detected
// here; the BufferWrapper constructors may presumably raise as well.
template <typename T>
void _bin_signal(const bp::object & bin_by, const bp::object & signal,
                 const bp::object & weight, bp::object & binned_sig,
                 bp::object & binned_sig_sigma, bp::object & bin_counts,
                 const bp::object & bin_edges, const T lower, const T upper,
                 const int* flags_data=nullptr,
                 const std::vector<int> & flags_shape={0,0},
                 const int flags_stride=0)
{
    // signal
    BufferWrapper<T> signal_buf ("signal", signal, false, std::vector<int>{-1, -1});
    if (signal_buf->strides[1] != signal_buf->itemsize)
        throw ValueError_exception("Argument 'signal' must be contiguous in last axis.");
    const int ndets = signal_buf->shape[0];
    const int nsamps = signal_buf->shape[1];
    T* signal_data = (T*)signal_buf->buf;

    // Check flags shape.  Anything that is not described as 1-d or 2-d is
    // rejected instead of being indexed blindly.
    if (flags_data) {
        if (flags_shape.size() == 2) {
            if (flags_shape[0] != ndets || flags_shape[1] != nsamps) {
                throw ValueError_exception("2D 'flags' array has incorrect shape");
            }
        }
        else if (flags_shape.size() != 1 || flags_shape[0] != nsamps) {
            throw ValueError_exception("1D 'flags' array has incorrect shape");
        }
    }

    // bin_by
    BufferWrapper<T> bin_by_buf ("bin_by", bin_by, false, std::vector<int>{nsamps});
    if (bin_by_buf->strides[0] != bin_by_buf->itemsize)
        throw ValueError_exception("Argument 'bin_by' must be a C-contiguous 1d array");
    const T* bin_by_data = (T*)bin_by_buf->buf;

    // weight: first wrap without a shape constraint to learn its rank ...
    BufferWrapper<T> weight_buf_temp ("weight", weight, true);
    if (weight_buf_temp->ndim == 1 && weight_buf_temp->strides[0] != weight_buf_temp->itemsize)
        throw ValueError_exception("Argument 'weight' must be a C-contiguous 1d array");
    else if (weight_buf_temp->ndim == 2 && weight_buf_temp->strides[1] != weight_buf_temp->itemsize)
        throw ValueError_exception("Argument 'weight' must be contiguous in last axis.");

    // ... then wrap again with the shape that rank implies.
    bool is_weight_2d = false;
    std::vector<int> weight_dims;
    if (weight_buf_temp->ndim == 2) {
        weight_dims.push_back(ndets);
        is_weight_2d = true;
    }
    if (weight_buf_temp->ndim >= 1) {
        weight_dims.push_back(nsamps);
    }

    BufferWrapper<T> weight_buf ("weight", weight, false, weight_dims);
    const T* weight_data = (T*)weight_buf->buf;

    // bin_edges
    BufferWrapper<T> bin_edges_buf ("bin_edges", bin_edges, false, std::vector<int>{-1});
    if (bin_edges_buf->strides[0] != bin_edges_buf->itemsize)
        throw ValueError_exception("Argument 'bin_edges' must be a C-contiguous 1d array");
    const int nbins = bin_edges_buf->shape[0] - 1;
    // Fewer than two edges would make _find_bin_index read outside the edge
    // array, and a value of -1 looks like it is treated as "any size" by the
    // shape checks below (see the {-1, -1} used for 'signal').
    if (nbins < 1)
        throw ValueError_exception("Argument 'bin_edges' must contain at least two edges.");
    const T* bin_edges_data = (T*)bin_edges_buf->buf;

    // binned_sig
    BufferWrapper<T> binned_sig_buf ("binned_sig", binned_sig, false, std::vector<int>{ndets, nbins});
    if (binned_sig_buf->strides[1] != binned_sig_buf->itemsize)
        throw ValueError_exception("Argument 'binned_sig' must be contiguous in last axis.");
    T* binned_sig_data = (T*)binned_sig_buf->buf;

    // binned_sig_sigma
    BufferWrapper<T> binned_sig_sigma_buf ("binned_sig_sigma", binned_sig_sigma, false, std::vector<int>{ndets, nbins});
    if (binned_sig_sigma_buf->strides[1] != binned_sig_sigma_buf->itemsize)
        throw ValueError_exception("Argument 'binned_sig_sigma' must be contiguous in last axis.");
    T* binned_sig_sigma_data = (T*)binned_sig_sigma_buf->buf;

    // bin_counts
    BufferWrapper<int> bin_counts_buf ("bin_counts", bin_counts, false, std::vector<int>{ndets, nbins});
    if (bin_counts_buf->strides[1] != bin_counts_buf->itemsize)
        throw ValueError_exception("Argument 'bin_counts' must be contiguous in last axis.");
    int* bin_counts_data = (int*)bin_counts_buf->buf;

    // Row strides in elements.  Kept signed and 64-bit so that
    // ``row * stride`` cannot wrap for large arrays and the byte stride is
    // not converted to unsigned by the division.
    const long long signal_stride =
        signal_buf->strides[0] / static_cast<long long>(sizeof(T));
    long long weight_stride = 0;
    if (is_weight_2d) {
        weight_stride = weight_buf->strides[0] / static_cast<long long>(sizeof(T));
    }
    const long long binned_sig_stride =
        binned_sig_buf->strides[0] / static_cast<long long>(sizeof(T));
    const long long binned_sig_sigma_stride =
        binned_sig_sigma_buf->strides[0] / static_cast<long long>(sizeof(T));
    const long long bin_counts_stride =
        bin_counts_buf->strides[0] / static_cast<long long>(sizeof(int));

    // Map from data column to bin index, computed once for all detectors.
    // -1 marks a sample that must be dropped: either it lies outside
    // [lower, upper] or _find_bin_index could not place it (it returns -1,
    // e.g. for NaN).
    std::vector<int> bin_indices(nsamps);
    for (int i = 0; i < nsamps; ++i) {
        const T x = bin_by_data[i];
        bin_indices[i] = (x < lower || x > upper)
            ? -1
            : _find_bin_index(bin_edges_data, x, nbins);
    }
    const int* sample_bin = bin_indices.data();

    // Without flags and with 1-d weights every detector has identical bin
    // counts, so compute them once and copy them to the other rows.  With
    // flags or per-detector weights the counts are built row by row below.
    const bool shared_counts = !flags_data && !is_weight_2d;
    if (shared_counts && ndets > 0) {
        for (int j = 0; j < nbins; ++j) {
            bin_counts_data[j] = 0;
        }
        for (int i = 0; i < nsamps; ++i) {
            const int bin = sample_bin[i];
            if (bin >= 0)
                bin_counts_data[bin] += weight_data[i];
        }

        #pragma omp parallel for
        for (int i = 1; i < ndets; ++i) {
            int* bin_counts_row = bin_counts_data + i * bin_counts_stride;
            for (int j = 0; j < nbins; ++j) {
                bin_counts_row[j] = bin_counts_data[j];
            }
        }
    }

    #pragma omp parallel for
    for (int i = 0; i < ndets; ++i) {
        const T* signal_row = signal_data + i * signal_stride;
        // weight_stride is 0 for 1-d weights, so every detector shares row 0.
        const T* weight_row = weight_data + i * weight_stride;
        // flags_stride is 0 for 1-d flags, so every detector shares row 0.
        const int* flags_row =
            flags_data ? flags_data + static_cast<long long>(i) * flags_stride : nullptr;

        T* binned_sig_row = binned_sig_data + i * binned_sig_stride;
        // The sigma output row doubles as the accumulator for sum((w*s)^2)
        // until the normalisation step overwrites it with the final value.
        T* binned_sig_sigma_row = binned_sig_sigma_data + i * binned_sig_sigma_stride;
        int* bin_counts_row = bin_counts_data + i * bin_counts_stride;

        // Zero out binned data
        for (int j = 0; j < nbins; ++j) {
            binned_sig_row[j] = 0;
            binned_sig_sigma_row[j] = 0;
            if (!shared_counts) {
                bin_counts_row[j] = 0;
            }
        }

        // Populate binned data
        for (int j = 0; j < nsamps; ++j) {
            const int bin = sample_bin[j];
            if (bin < 0)
                continue;
            if (flags_row && flags_row[j])
                continue;

            if (!shared_counts) {
                bin_counts_row[bin] += weight_row[j];
            }

            if (bin_counts_row[bin] > 0) {
                const T ws = weight_row[j] * signal_row[j];
                binned_sig_row[bin] += ws;
                binned_sig_sigma_row[bin] += ws * ws;
            }
        }

        // Normalize
        for (int j = 0; j < nbins; ++j) {
            const int count = bin_counts_row[j];
            if (count > 0) {
                binned_sig_row[j] /= count;
                const T sq_mean = binned_sig_sigma_row[j] / count;
                binned_sig_sigma_row[j] =
                    std::sqrt(std::abs(sq_mean -
                                       binned_sig_row[j] * binned_sig_row[j])) /
                    std::sqrt(count);
            }
            else {
                binned_sig_row[j] = std::numeric_limits<T>::quiet_NaN();
                binned_sig_sigma_row[j] = std::numeric_limits<T>::quiet_NaN();
            }
        }
    }
}

// Python-facing entry point for binning without a flags array.
//
// Looks up the dtype of ``signal`` and forwards every argument to the
// matching _bin_signal<T> instantiation, narrowing ``lower``/``upper`` to
// float for float32 data.  Throws TypeError_exception for any other dtype.
void bin_signal(const bp::object & bin_by, const bp::object & signal,
                const bp::object & weight, bp::object & binned_sig,
                bp::object & binned_sig_sigma, bp::object & bin_counts,
                const bp::object & bin_edges, const double lower,
                const double upper)
{
    const int signal_dtype = get_dtype(signal);

    switch (signal_dtype) {
    case NPY_FLOAT:
        _bin_signal<float>(bin_by, signal, weight, binned_sig, binned_sig_sigma,
                           bin_counts, bin_edges,
                           static_cast<float>(lower), static_cast<float>(upper));
        break;
    case NPY_DOUBLE:
        _bin_signal<double>(bin_by, signal, weight, binned_sig, binned_sig_sigma,
                            bin_counts, bin_edges, lower, upper);
        break;
    default:
        throw TypeError_exception("Only float32 or float64 arrays are supported.");
    }
}

// Python-facing entry point for binning with a flags array.
//
// ``flags`` is wrapped as an int buffer and must be 1-d or 2-d and contiguous
// in its last axis; its shape and (for 2-d) row stride in elements are passed
// to _bin_signal<T>, which validates the shape against ``signal``.  The
// template instantiation is chosen from the dtype of ``signal``.
//
// Throws ValueError_exception if ``flags`` has the wrong rank or layout and
// TypeError_exception if ``signal`` is neither float32 nor float64.
void bin_flagged_signal(const bp::object & bin_by, const bp::object & signal,
                        const bp::object & weight, bp::object & binned_sig,
                        bp::object & binned_sig_sigma, bp::object & bin_counts,
                        const bp::object & bin_edges, const double lower,
                        const double upper, const bp::object & flags)
{
    // Get data type
    const int dtype = get_dtype(signal);

    // flags
    BufferWrapper<int> flags_buf ("flags", flags, false);
    const int flags_ndim = flags_buf->ndim;
    // Reject other ranks up front: shape[0] is read below, and a higher-rank
    // array would otherwise be passed on as if it were 1-d.
    if (flags_ndim != 1 && flags_ndim != 2)
        throw ValueError_exception("Argument 'flags' must be a 1d or 2d array");
    if (flags_ndim == 1 && flags_buf->strides[0] != flags_buf->itemsize)
        throw ValueError_exception("Argument 'flags' must be a C-contiguous 1d array");
    else if (flags_ndim == 2 && flags_buf->strides[1] != flags_buf->itemsize)
        throw ValueError_exception("Argument 'flags' must be contiguous in last axis.");

    std::vector<int> flags_shape;
    flags_shape.push_back(flags_buf->shape[0]);

    // Row stride in elements; stays 0 for 1-d flags so that every detector
    // reads the same row.
    int flags_stride = 0;
    if (flags_ndim == 2) {
        flags_shape.push_back(flags_buf->shape[1]);
        // Divide by a signed size so the byte stride is not converted to
        // unsigned before the division.
        flags_stride = static_cast<int>(
            flags_buf->strides[0] / static_cast<long long>(sizeof(int)));
    }

    const int* flags_data = (int*)flags_buf->buf;

    if (dtype == NPY_FLOAT) {
        _bin_signal<float>(bin_by, signal, weight, binned_sig, binned_sig_sigma,
                           bin_counts, bin_edges,
                           static_cast<float>(lower), static_cast<float>(upper),
                           flags_data, flags_shape, flags_stride);
    }
    else if (dtype == NPY_DOUBLE) {
        _bin_signal<double>(bin_by, signal, weight, binned_sig, binned_sig_sigma,
                            bin_counts, bin_edges, lower, upper,
                            flags_data, flags_shape, flags_stride);
    }
    else {
        throw TypeError_exception("Only float32 or float64 arrays are supported.");
    }
}


PYBINDINGS("so3g")
{
bp::def("nmat_detvecs_apply", nmat_detvecs_apply);
Expand Down Expand Up @@ -1418,4 +1703,44 @@ PYBINDINGS("so3g")
" linear_ncount: Number (int) of samples to use on each end, when measuring mean level for 'linear'"
" detrend. Must be a positive integer or -1. If -1, nsamps / 2 will be used. Values "
" larger than 1 suppress the influence of white noise.\n");
}
// Expose bin_signal to Python; the string literal is its Python docstring.
bp::def("bin_signal", bin_signal,
        "bin_signal(bin_by, signal, weight, binned_sig, binned_sig_sigma, bin_counts, bin_edges, lower, upper)\n"
        "\n"
        "Bin time-ordered data by ``bin_by`` and return the binned signal and its standard deviation.\n"
        "This function uses OMP to parallelize over the dets (rows) axis. Supports unequal bin widths.\n"
        "Args:\n"
        "  bin_by: the array (float32/float64) by which signal is binned with shape (nsamp)\n"
        "  signal: the signal array (float32/float64) to be binned with shape (ndet,nsamp)\n"
        "  weight: array (float32/float64) of weights for the signal values. May have shapes\n"
        "        of (nsamps) or (ndets, nsamps)\n"
        "  binned_sig: binned signal array (float32/float64) with shape (ndet,nbin).\n"
        "        Modified in place.\n"
        "  binned_sig_sigma: estimated sigma of binned signal (float32/float64) with shape (ndet,nbin).\n"
        "        Modified in place.\n"
        "  bin_counts: counts of binned samples (int32) with shape (ndet,nbin). Modified in place.\n"
        "  bin_edges: array (float32/float64) of bin edges with length=nbins+1. Must be monotonically increasing\n"
        "        but may have different widths.\n"
        "  lower: lower bin range (float64). Data points falling outside this range will be ignored.\n"
        "  upper: upper bin range (float64). Data points falling outside this range will be ignored.\n");
// Expose bin_flagged_signal to Python; the string literal is its Python
// docstring.
bp::def("bin_flagged_signal", bin_flagged_signal,
        "bin_flagged_signal(bin_by, signal, weight, binned_sig, binned_sig_sigma, bin_counts, bin_edges, lower, upper, flags)\n"
        "\n"
        "Bin time-ordered data by ``bin_by`` and return the binned signal and its standard deviation.\n"
        "This function uses OMP to parallelize over the dets (rows) axis. Supports unequal bin widths.\n"
        "Args:\n"
        "  bin_by: the array (float32/float64) by which signal is binned with shape (nsamp)\n"
        "  signal: the signal array (float32/float64) to be binned with shape (ndet,nsamp)\n"
        "  weight: array (float32/float64) of weights for the signal values. May have shapes\n"
        "        of (nsamps) or (ndets, nsamps)\n"
        "  binned_sig: binned signal array (float32/float64) with shape (ndet,nbin).\n"
        "        Modified in place.\n"
        "  binned_sig_sigma: estimated sigma of binned signal (float32/float64) with shape (ndet,nbin).\n"
        "        Modified in place.\n"
        "  bin_counts: counts of binned samples (int32) with shape (ndet,nbin). Modified in place.\n"
        "  bin_edges: array (float32/float64) of bin edges with length=nbins+1. Must be monotonically increasing\n"
        "        but may have different widths.\n"
        "  lower: lower bin range (float64). Data points falling outside this range will be ignored.\n"
        "  upper: upper bin range (float64). Data points falling outside this range will be ignored.\n"
        "  flags: array (int32) indicating whether to exclude flagged samples when binning the signal.\n"
        "        Can be of shape (nsamp) or (ndet,nsamp).\n");
}
Loading