diff --git a/SparseDiffEngine b/SparseDiffEngine index bcdb0f0..63a38d2 160000 --- a/SparseDiffEngine +++ b/SparseDiffEngine @@ -1 +1 @@ -Subproject commit bcdb0f0e74670b80b0f60b7ff02338dfa325fdf0 +Subproject commit 63a38d2054d7dd82e117b2cb7f4afa802be138b6 diff --git a/sparsediffpy/_bindings/atoms/diag_mat.h b/sparsediffpy/_bindings/atoms/diag_mat.h new file mode 100644 index 0000000..284dfe7 --- /dev/null +++ b/sparsediffpy/_bindings/atoms/diag_mat.h @@ -0,0 +1,32 @@ +#ifndef ATOM_DIAG_MAT_H +#define ATOM_DIAG_MAT_H + +#include "common.h" + +static PyObject *py_make_diag_mat(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + + if (!PyArg_ParseTuple(args, "O", &child_capsule)) + { + return NULL; + } + + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + return NULL; + } + + expr *node = new_diag_mat(child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create diag_mat node"); + return NULL; + } + + expr_retain(node); + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_DIAG_MAT_H */ diff --git a/sparsediffpy/_bindings/atoms/upper_tri.h b/sparsediffpy/_bindings/atoms/upper_tri.h new file mode 100644 index 0000000..a007ba4 --- /dev/null +++ b/sparsediffpy/_bindings/atoms/upper_tri.h @@ -0,0 +1,32 @@ +#ifndef ATOM_UPPER_TRI_H +#define ATOM_UPPER_TRI_H + +#include "common.h" + +static PyObject *py_make_upper_tri(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + + if (!PyArg_ParseTuple(args, "O", &child_capsule)) + { + return NULL; + } + + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + return NULL; + } + + expr *node = new_upper_tri(child); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create upper_tri node"); + return NULL; + } + + expr_retain(node); + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_UPPER_TRI_H */ diff --git a/sparsediffpy/_bindings/atoms/vstack.h b/sparsediffpy/_bindings/atoms/vstack.h new file mode 100644 index 0000000..df561ad --- /dev/null +++ b/sparsediffpy/_bindings/atoms/vstack.h @@ -0,0 +1,51 @@ +#ifndef ATOM_VSTACK_H +#define ATOM_VSTACK_H + +#include "common.h" + +static PyObject *py_make_vstack(PyObject *self, PyObject *args) +{ + (void) self; + PyObject *list_obj; + if (!PyArg_ParseTuple(args, "O", &list_obj)) + { + return NULL; + } + if (!PyList_Check(list_obj)) + { + PyErr_SetString(PyExc_TypeError, + "First argument must be a list of expr capsules"); + return NULL; + } + Py_ssize_t n_args = PyList_Size(list_obj); + if (n_args == 0) + { + PyErr_SetString(PyExc_ValueError, "List of expr capsules cannot be empty"); + return NULL; + } + expr **expr_args = (expr **) calloc(n_args, sizeof(expr *)); + for (Py_ssize_t i = 0; i < n_args; ++i) + { + PyObject *item = PyList_GetItem(list_obj, i); + expr *e = (expr *) PyCapsule_GetPointer(item, EXPR_CAPSULE_NAME); + if (!e) + { + free(expr_args); + PyErr_SetString(PyExc_ValueError, "Invalid expr capsule in list"); + return NULL; + } + expr_args[i] = e; + } + int n_vars = expr_args[0]->n_vars; + expr *node = new_vstack(expr_args, (int) n_args, n_vars); + free(expr_args); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create vstack node"); + return NULL; + } + expr_retain(node); + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif // ATOM_VSTACK_H diff --git a/sparsediffpy/_bindings/bindings.c b/sparsediffpy/_bindings/bindings.c index f9b1ac9..31d68ba 100644 --- a/sparsediffpy/_bindings/bindings.c +++ b/sparsediffpy/_bindings/bindings.c @@ -8,6 +8,7 @@ #include "atoms/atanh.h" #include "atoms/broadcast.h" #include "atoms/cos.h" +#include "atoms/diag_mat.h" #include "atoms/diag_vec.h" #include "atoms/entr.h" #include "atoms/exp.h" @@ -40,8 +41,10 @@ #include "atoms/tanh.h" #include "atoms/trace.h" #include "atoms/transpose.h" +#include "atoms/upper_tri.h" #include "atoms/variable.h" #include "atoms/vector_mult.h" +#include "atoms/vstack.h" #include "atoms/xexp.h" /* Include problem bindings */ @@ -80,6 +83,8 @@ static PyMethodDef DNLPMethods[] = { {"make_hstack", py_make_hstack, METH_VARARGS, "Create hstack node from list of expr capsules and n_vars (make_hstack([e1, " "e2, ...], n_vars))"}, + {"make_vstack", py_make_vstack, METH_VARARGS, + "Create vstack node from list of expr capsules (make_vstack([e1, e2, ...]))"}, {"make_sum", py_make_sum, METH_VARARGS, "Create sum node"}, {"make_neg", py_make_neg, METH_VARARGS, "Create neg node"}, {"make_normal_cdf", py_make_normal_cdf, METH_VARARGS, "Create normal_cdf node"}, @@ -100,12 +105,14 @@ static PyMethodDef DNLPMethods[] = { "Create prod_axis_one node"}, {"make_sin", py_make_sin, METH_VARARGS, "Create sin node"}, {"make_cos", py_make_cos, METH_VARARGS, "Create cos node"}, + {"make_diag_mat", py_make_diag_mat, METH_VARARGS, "Create diag_mat node"}, {"make_diag_vec", py_make_diag_vec, METH_VARARGS, "Create diag_vec node"}, {"make_tan", py_make_tan, METH_VARARGS, "Create tan node"}, {"make_sinh", py_make_sinh, METH_VARARGS, "Create sinh node"}, {"make_tanh", py_make_tanh, METH_VARARGS, "Create tanh node"}, {"make_asinh", py_make_asinh, METH_VARARGS, "Create asinh node"}, {"make_atanh", py_make_atanh, METH_VARARGS, "Create atanh node"}, + {"make_upper_tri", py_make_upper_tri, METH_VARARGS, "Create upper_tri node"}, {"make_broadcast", py_make_broadcast, METH_VARARGS, "Create broadcast node"}, {"make_entr", py_make_entr, METH_VARARGS, "Create entr node"}, {"make_logistic", py_make_logistic, METH_VARARGS, "Create logistic node"},