diff --git a/include/atoms/affine.h b/include/atoms/affine.h
index 94afe8b..3e71bf0 100644
--- a/include/atoms/affine.h
+++ b/include/atoms/affine.h
@@ -58,6 +58,12 @@ expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A);
 expr *new_right_matmul_dense(expr *param_node, expr *u, int m, int n,
                              const double *data);
 
+/* Kronecker product with constant on the left: Z = kron(C, u) where C is a
+ * constant sparse matrix (copied by the node) and u is a (p x q) expression.
+ * Output shape (C->m * p, C->n * q). param_node must be NULL (the parameter
+ * path is reserved for a future change); u->size must equal p * q. */
+expr *new_kron_left(expr *param_node, expr *u, const CSR_Matrix *C, int p, int q);
+
 /* Scalar multiplication: a * f(x) where a comes from param_node */
 expr *new_scalar_mult(expr *param_node, expr *child);
 
diff --git a/include/subexpr.h b/include/subexpr.h
index 9c57d23..fc8e4ed 100644
--- a/include/subexpr.h
+++ b/include/subexpr.h
@@ -170,6 +170,30 @@ typedef struct matmul_expr
     int *idx_map_Hg;
 } matmul_expr;
 
+/* Kronecker product with a constant on the left: Z = kron(C, X) where C is
+ * a constant (m x n) sparse matrix and X is an expression of shape (p x q).
+ * Output has shape (m*p, n*q). The atom is affine in X; the param_source
+ * slot is reserved for a future update that makes C an updatable parameter.
+ *
+ * We cache the active entries of C (one per nonzero of C) so that all
+ * inner loops run in O(nnz_C * p * q) rather than touching zero rows of
+ * the output. This automatically collapses to O(m * p * q) when C = I_m,
+ * with no special case in the code. */
+typedef struct kron_left_expr
+{
+    expr base;
+    CSR_Matrix *C; /* constant matrix, owned */
+    int p, q;      /* child shape (m, n are C->m, C->n) */
+    /* active-entry tables (length C->nnz), filled in constructor */
+    int n_active;
+    int *active_i;   /* row index i of each nonzero */
+    int *active_j;   /* col index j of each nonzero */
+    int *active_idx; /* index into C->x */
+    /* parameter slot (not wired up yet — param_source must be NULL) */
+    expr *param_source;
+    void (*refresh_param_values)(struct kron_left_expr *);
+} kron_left_expr;
+
 /* Index/slicing: y = child[indices] where indices is a list of flat positions */
 typedef struct index_expr
 {
diff --git a/src/atoms/affine/kron_left.c b/src/atoms/affine/kron_left.c
new file mode 100644
index 0000000..28ffb2c
--- /dev/null
+++ b/src/atoms/affine/kron_left.c
@@ -0,0 +1,311 @@
+/*
+ * Copyright 2026 Daniel Cederberg and William Zhang
+ *
+ * This file is part of the SparseDiffEngine project.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "atoms/affine.h"
+#include "subexpr.h"
+#include "utils/tracked_alloc.h"
+#include <assert.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+/* Kronecker product with constant on the left: Z = kron(C, X) where
+ *   C has shape (m, n) and is a constant sparse matrix,
+ *   X has shape (p, q) and is an expression.
+ * Output Z has shape (m*p, n*q), stored column-major as vec(Z) of length
+ * m*p*n*q.
+ *
+ * Key identity: Z[i*p+k, j*q+l] = C[i,j] * X[k,l].
+ * In column-major: vec(Z)[r] with r = (j*q+l)*(m*p) + i*p + k
+ * depends on vec(X)[s] with s = l*p + k and coefficient C[i,j].
+ *
+ * The atom is affine in X: each output row r (when C[i,j] != 0) is a
+ * scaled copy of child row s of the child's Jacobian, and the weighted
+ * Hessian inherits the child's sparsity with an adjoint accumulation
+ * over the same index pattern.
+ *
+ * All inner loops iterate only over nonzeros of C (cached in the
+ * active_i / active_j / active_idx tables at construction). No explicit
+ * identity-detection is needed: for C = I_m, nnz_C == m and the work
+ * naturally drops to O(m * p * q) without any special-case code. */
+
+/* Forward pass: evaluates the child at u, then writes kron(C, X) into
+ * node->value in column-major order. */
+static void forward(expr *node, const double *u)
+{
+    kron_left_expr *kron = (kron_left_expr *) node;
+    expr *x = node->left;
+    const CSR_Matrix *C = kron->C;
+    const int p = kron->p;
+    const int q = kron->q;
+    const int out_rows = C->m * p;
+
+    x->forward(x, u);
+
+    /* Blocks that belong to zero entries of C are never written below, so
+     * the whole output is cleared first. */
+    memset(node->value, 0, (size_t) node->size * sizeof(double));
+
+    for (int t = 0; t < kron->n_active; t++)
+    {
+        const double cij = C->x[kron->active_idx[t]];
+        /* Top-left corner of block (i, j): row i*p of column j*q. */
+        double *block = node->value + kron->active_j[t] * q * out_rows +
+                        kron->active_i[t] * p;
+        for (int l = 0; l < q; l++)
+        {
+            double *dst = block + l * out_rows;
+            const double *src = x->value + l * p;
+            for (int k = 0; k < p; k++)
+            {
+                dst[k] = cij * src[k];
+            }
+        }
+    }
+}
+
+/* Reports whatever the child expression reports for affinity. */
+static bool is_affine(const expr *node)
+{
+    expr *x = node->left;
+    return x->is_affine(x);
+}
+
+/* Two-pass construction over active C entries × (l, k):
+ *   pass 1 fills row_nnz[r] for every active output row,
+ *   pass 2 writes column indices into the already-allocated CSR.
+ * Rows r that don't correspond to an active (i, j) stay at 0 nnz.
+ *
+ * Work: O(nnz_C * p * q * avg_nnz_per_Jchild_row). For C = I_m this is
+ * O(m * p * q * avg_Jchild_row_nnz), i.e.
a factor-of-n reduction vs a
+ * naive iteration over every output row of Z. */
+static void jacobian_init_impl(expr *node)
+{
+    kron_left_expr *lnode = (kron_left_expr *) node;
+    expr *child = node->left;
+    CSR_Matrix *C = lnode->C;
+    int p = lnode->p, q = lnode->q;
+    int mp = C->m * p;
+    int out_size = node->size;
+
+    jacobian_init(child);
+    CSR_Matrix *Jchild = child->jacobian;
+
+    /* Pass 1: row_nnz[r] = Jchild row-nnz for active r, else 0. */
+    int *row_nnz = (int *) SP_CALLOC((size_t) out_size, sizeof(int));
+    for (int t = 0; t < lnode->n_active; t++)
+    {
+        int i = lnode->active_i[t];
+        int j = lnode->active_j[t];
+        for (int l = 0; l < q; l++)
+        {
+            int r_col_base = (j * q + l) * mp + i * p;
+            for (int k = 0; k < p; k++)
+            {
+                int s = l * p + k;
+                row_nnz[r_col_base + k] = Jchild->p[s + 1] - Jchild->p[s];
+            }
+        }
+    }
+
+    /* Cumulative sum into a local buffer; we'll memcpy into the
+     * Jacobian's p[] after allocation. */
+    int *Jp = (int *) SP_MALLOC((size_t) (out_size + 1) * sizeof(int));
+    int total_nnz = 0;
+    for (int r = 0; r < out_size; r++)
+    {
+        Jp[r] = total_nnz;
+        total_nnz += row_nnz[r];
+    }
+    Jp[out_size] = total_nnz;
+    free(row_nnz);
+
+    node->jacobian = new_csr_matrix(out_size, node->n_vars, total_nnz);
+    memcpy(node->jacobian->p, Jp, (size_t) (out_size + 1) * sizeof(int));
+    free(Jp);
+
+    /* Pass 2: column indices are a copy of the corresponding Jchild row. */
+    for (int t = 0; t < lnode->n_active; t++)
+    {
+        int i = lnode->active_i[t];
+        int j = lnode->active_j[t];
+        for (int l = 0; l < q; l++)
+        {
+            int r_col_base = (j * q + l) * mp + i * p;
+            for (int k = 0; k < p; k++)
+            {
+                int s = l * p + k;
+                int r = r_col_base + k;
+                int cs = Jchild->p[s];
+                int row_nnz_r = Jchild->p[s + 1] - cs;
+                int row_start = node->jacobian->p[r];
+                memcpy(node->jacobian->i + row_start, Jchild->i + cs,
+                       (size_t) row_nnz_r * sizeof(int));
+            }
+        }
+    }
+}
+
+/* Refreshes the child's Jacobian values, then fills every active output
+ * row r with C[i,j] times the values of child row s = l*p + k. The row
+ * offsets J->p are the ones laid out by jacobian_init_impl, so that
+ * function must have run first. */
+static void eval_jacobian(expr *node)
+{
+    kron_left_expr *lnode = (kron_left_expr *) node;
+    expr *child = node->left;
+    CSR_Matrix *C = lnode->C;
+    CSR_Matrix *Jchild = child->jacobian;
+    CSR_Matrix *J = node->jacobian;
+    int p = lnode->p, q = lnode->q;
+    int mp = C->m * p;
+
+    child->eval_jacobian(child);
+
+    for (int t = 0; t < lnode->n_active; t++)
+    {
+        int i = lnode->active_i[t];
+        int j = lnode->active_j[t];
+        double cij = C->x[lnode->active_idx[t]];
+        for (int l = 0; l < q; l++)
+        {
+            int r_col_base = (j * q + l) * mp + i * p;
+            for (int k = 0; k < p; k++)
+            {
+                int s = l * p + k;
+                int r = r_col_base + k;
+                int cs = Jchild->p[s];
+                int row_nnz_r = Jchild->p[s + 1] - cs;
+                int row_start = J->p[r];
+                for (int u = 0; u < row_nnz_r; u++)
+                {
+                    J->x[row_start + u] = cij * Jchild->x[cs + u];
+                }
+            }
+        }
+    }
+}
+
+/* Initializes the child's weighted Hessian, mirrors its sparsity pattern
+ * for this node and allocates the child-sized weight workspace. */
+static void wsum_hess_init_impl(expr *node)
+{
+    expr *child = node->left;
+
+    wsum_hess_init(child);
+
+    /* Linear in X: Hessian sparsity equals the child's. */
+    node->wsum_hess = new_csr_copy_sparsity(child->wsum_hess);
+
+    /* Workspace for the reverse-mode weight vector passed down to child. */
+    node->work->dwork = (double *) SP_MALLOC((size_t) child->size * sizeof(double));
+}
+
+/* Weighted Hessian: folds the output weights w back onto the child's
+ * entries, evaluates the child's weighted Hessian with those weights and
+ * copies the resulting values into this node's Hessian. */
+static void eval_wsum_hess(expr *node, const double *w)
+{
+    kron_left_expr *lnode = (kron_left_expr *) node;
+    expr *child = node->left;
+    CSR_Matrix *C = lnode->C;
+    int p = lnode->p, q = lnode->q;
+    int mp = C->m * p;
+    int child_size = child->size;
+    double *w_child = node->work->dwork;
+
+    /* Adjoint of the forward pass: w_child[s] = sum_{(i,j,k,l): s=l*p+k}
+     * C[i,j] * w[(j*q+l)*mp + i*p + k]. */
+    memset(w_child, 0, (size_t) child_size * sizeof(double));
+    for (int t = 0; t < lnode->n_active; t++)
+    {
+        int i = lnode->active_i[t];
+        int j = lnode->active_j[t];
+        double cij = C->x[lnode->active_idx[t]];
+        for (int l = 0; l < q; l++)
+        {
+            int r_col_base = (j * q + l) * mp + i * p;
+            for (int k = 0; k < p; k++)
+            {
+                int s = l * p + k;
+                w_child[s] += cij * w[r_col_base + k];
+            }
+        }
+    }
+
+    child->eval_wsum_hess(child, w_child);
+    memcpy(node->wsum_hess->x, child->wsum_hess->x,
+           (size_t) node->wsum_hess->nnz * sizeof(double));
+}
+
+/* Releases the node's copy of C and the active-entry tables, then clears
+ * the pointers.
+ * NOTE(review): the constructor stores param_node without a matching
+ * expr_retain, so the free_expr below has no visible counterpart; confirm
+ * the reference-counting contract before the parameter path is wired up. */
+static void free_type_data(expr *node)
+{
+    kron_left_expr *lnode = (kron_left_expr *) node;
+    free_csr_matrix(lnode->C);
+    free(lnode->active_i);
+    free(lnode->active_j);
+    free(lnode->active_idx);
+    if (lnode->param_source != NULL)
+    {
+        free_expr(lnode->param_source);
+    }
+    lnode->C = NULL;
+    lnode->active_i = NULL;
+    lnode->active_j = NULL;
+    lnode->active_idx = NULL;
+    lnode->param_source = NULL;
+}
+
+/* Constructs Z = kron(C, u).
+ *
+ *   param_node  must be NULL (a parameter-valued C is not supported yet)
+ *   u           child expression with u->size == p * q; retained by the node
+ *   C           constant sparse matrix; the node keeps its own copy
+ *   p, q        shape of u, both non-negative
+ *
+ * Returns the new node of shape (C->m * p, C->n * q). Every argument error
+ * is reported on stderr followed by exit(1), and all of them are detected
+ * before anything is allocated or retained. */
+expr *new_kron_left(expr *param_node, expr *u, const CSR_Matrix *C, int p, int q)
+{
+    /* Two negative dimensions can multiply to u->size, so the product test
+     * below is not sufficient on its own. */
+    if (p < 0 || q < 0)
+    {
+        fprintf(stderr, "Error in new_kron_left: negative child shape %d x %d\n",
+                p, q);
+        exit(1);
+    }
+
+    if (u->size != p * q)
+    {
+        fprintf(stderr,
+                "Error in new_kron_left: child size %d != p*q = %d*%d = %d\n",
+                u->size, p, q, p * q);
+        exit(1);
+    }
+
+    /* Parameter slot is reserved but not yet wired up. */
+    if (param_node != NULL)
+    {
+        fprintf(stderr, "Error in new_kron_left: parameter for kron C "
+                        "not supported yet\n");
+        exit(1);
+    }
+
+    int m = C->m;
+    int n = C->n;
+
+    kron_left_expr *lnode = (kron_left_expr *) SP_CALLOC(1, sizeof(kron_left_expr));
+    expr *node = &lnode->base;
+    init_expr(node, m * p, n * q, u->n_vars, forward, jacobian_init_impl,
+              eval_jacobian, is_affine, wsum_hess_init_impl, eval_wsum_hess,
+              free_type_data);
+    node->left = u;
+    expr_retain(u);
+
+    lnode->p = p;
+    lnode->q = q;
+    lnode->C = new_csr(C);
+    lnode->param_source = NULL;
+
+    /* Precompute active (i, j) tuples and their offset into C->x. */
+    lnode->n_active = C->nnz;
+    lnode->active_i = (int *) SP_MALLOC((size_t) C->nnz * sizeof(int));
+    lnode->active_j = (int *) SP_MALLOC((size_t) C->nnz * sizeof(int));
+    lnode->active_idx = (int *) SP_MALLOC((size_t) C->nnz * sizeof(int));
+    int t = 0;
+    for (int i = 0; i < m; i++)
+    {
+        for (int idx = C->p[i]; idx < C->p[i + 1]; idx++)
+        {
+            lnode->active_i[t] = i;
+            lnode->active_j[t] = C->i[idx];
+            lnode->active_idx[t] = idx;
+            t++;
+        }
+    }
+    /* The row pointers must account for exactly C->nnz stored entries. */
+    assert(t == C->nnz);
+
+    return node;
+}
diff --git a/tests/all_tests.c b/tests/all_tests.c
index efaad33..290b3dd 100644
--- a/tests/all_tests.c
+++ b/tests/all_tests.c
@@ -9,6 +9,7 @@
 #include "forward_pass/affine/test_broadcast.h"
 #include "forward_pass/affine/test_diag_mat.h"
 #include "forward_pass/affine/test_hstack.h"
+#include "forward_pass/affine/test_kron_left.h"
 #include "forward_pass/affine/test_left_matmul_dense.h"
 #include "forward_pass/affine/test_linear_op.h"
 #include "forward_pass/affine/test_neg.h"
@@ -28,6 +29,7 @@
 #include "jacobian_tests/affine/test_diag_mat.h"
 #include "jacobian_tests/affine/test_hstack.h"
 #include "jacobian_tests/affine/test_index.h"
+#include "jacobian_tests/affine/test_kron_left.h"
 #include "jacobian_tests/affine/test_left_matmul.h"
 #include "jacobian_tests/affine/test_neg.h"
 #include "jacobian_tests/affine/test_promote.h"
@@ -68,6 +70,7 @@
 #include "wsum_hess/affine/test_diag_mat.h"
 #include "wsum_hess/affine/test_hstack.h"
 #include "wsum_hess/affine/test_index.h"
+#include "wsum_hess/affine/test_kron_left.h"
 #include "wsum_hess/affine/test_left_matmul.h"
 #include
"wsum_hess/affine/test_right_matmul.h"
 #include "wsum_hess/affine/test_scalar_mult.h"
@@ -134,6 +137,8 @@ int main(void)
     mu_run_test(test_forward_prod_axis_one, tests_run);
     mu_run_test(test_matmul, tests_run);
     mu_run_test(test_left_matmul_dense, tests_run);
+    mu_run_test(test_kron_left_forward, tests_run);
+    mu_run_test(test_kron_left_forward_identity, tests_run);
     mu_run_test(test_diag_mat_forward, tests_run);
     mu_run_test(test_upper_tri_forward_4x4, tests_run);
 
@@ -212,6 +217,8 @@ int main(void)
     mu_run_test(test_jacobian_left_matmul_log, tests_run);
     mu_run_test(test_jacobian_left_matmul_log_matrix, tests_run);
     mu_run_test(test_jacobian_left_matmul_exp_composite, tests_run);
+    mu_run_test(test_jacobian_kron_left_log, tests_run);
+    mu_run_test(test_jacobian_kron_left_log_matrix, tests_run);
     mu_run_test(test_jacobian_right_matmul_log, tests_run);
     mu_run_test(test_jacobian_right_matmul_log_vector, tests_run);
     mu_run_test(test_jacobian_matmul, tests_run);
@@ -276,6 +283,8 @@ int main(void)
     mu_run_test(test_wsum_hess_left_matmul, tests_run);
     mu_run_test(test_wsum_hess_left_matmul_matrix, tests_run);
     mu_run_test(test_wsum_hess_left_matmul_exp_composite, tests_run);
+    mu_run_test(test_wsum_hess_kron_left, tests_run);
+    mu_run_test(test_wsum_hess_kron_left_composite, tests_run);
     mu_run_test(test_wsum_hess_matmul, tests_run);
     mu_run_test(test_wsum_hess_matmul_yx, tests_run);
     mu_run_test(test_wsum_hess_right_matmul, tests_run);
diff --git a/tests/forward_pass/affine/test_kron_left.h b/tests/forward_pass/affine/test_kron_left.h
new file mode 100644
index 0000000..3d0e7b2
--- /dev/null
+++ b/tests/forward_pass/affine/test_kron_left.h
@@ -0,0 +1,97 @@
+#include /* NOTE(review): header name missing in this patch text; memcpy below needs <string.h> */
+#include /* NOTE(review): header name missing in this patch text - confirm */
+#include /* NOTE(review): header name missing in this patch text - confirm */
+
+#include "atoms/affine.h"
+#include "expr.h"
+#include "minunit.h"
+#include "test_helpers.h"
+
+const char *test_kron_left_forward(void)
+{
+    /* Test: Z = kron(C, X) where
+     *   C is 2x2 sparse: [[1, 2], [0, 3]]
+     *   X is 2x2 variable (col-major): [[1, 3], [2, 4]]
+     *
+     * kron(C, X) = [[1*X, 2*X], [0*X, 3*X]]
+     *            = [[1, 3, 2, 6],
+     *               [2, 4, 4, 8],
+     *               [0, 0, 3, 9],
+     *               [0, 0, 6, 12]]
+     */
+    expr *X = new_variable(2, 2, 0, 4);
+
+    CSR_Matrix *C = new_csr_matrix(2, 2, 3);
+    int C_p[3] = {0, 2, 3};
+    int C_i[3] = {0, 1, 1};
+    double C_x[3] = {1.0, 2.0, 3.0};
+    memcpy(C->p, C_p, 3 * sizeof(int));
+    memcpy(C->i, C_i, 3 * sizeof(int));
+    memcpy(C->x, C_x, 3 * sizeof(double));
+
+    expr *Z = new_kron_left(NULL, X, C, 2, 2);
+
+    /* X = [[1,3],[2,4]] in column-major */
+    double u[4] = {1.0, 2.0, 3.0, 4.0};
+
+    Z->forward(Z, u);
+
+    /* (4x4) column-major */
+    double expected[16] = {
+        1.0, 2.0, 0.0, 0.0, /* col 0 */
+        3.0, 4.0, 0.0, 0.0, /* col 1 */
+        2.0, 4.0, 3.0, 6.0, /* col 2 */
+        6.0, 8.0, 9.0, 12.0 /* col 3 */
+    };
+
+    mu_assert("kron_left d1 != 4", Z->d1 == 4);
+    mu_assert("kron_left d2 != 4", Z->d2 == 4);
+    mu_assert("kron_left size != 16", Z->size == 16);
+    mu_assert("kron_left forward values", cmp_double_array(Z->value, expected, 16));
+
+    free_csr_matrix(C);
+    free_expr(Z);
+    return 0;
+}
+
+const char *test_kron_left_forward_identity(void)
+{
+    /* Identity path: Z = kron(I_3, X) with X a (2 x 2) variable.
+     * Result is block-diagonal: three copies of X stacked along the
+     * diagonal. Verifies the sparse-driven path collapses to the
+     * right block structure with no identity-detection code. */
+    expr *X = new_variable(2, 2, 0, 4);
+
+    CSR_Matrix *I3 = new_csr_matrix(3, 3, 3);
+    int Ip[4] = {0, 1, 2, 3};
+    int Ii[3] = {0, 1, 2};
+    double Ix[3] = {1.0, 1.0, 1.0};
+    memcpy(I3->p, Ip, 4 * sizeof(int));
+    memcpy(I3->i, Ii, 3 * sizeof(int));
+    memcpy(I3->x, Ix, 3 * sizeof(double));
+
+    expr *Z = new_kron_left(NULL, X, I3, 2, 2);
+
+    /* X = [[1,3],[2,4]] in column-major */
+    double u[4] = {1.0, 2.0, 3.0, 4.0};
+    Z->forward(Z, u);
+
+    /* kron(I_3, X) is 6x6, column-major: each column holds one
+     * column of X at a different row offset. */
+    double expected[36] = {
+        1.0, 2.0, 0.0, 0.0, 0.0, 0.0, /* col 0: X[:,0] at rows 0-1 */
+        3.0, 4.0, 0.0, 0.0, 0.0, 0.0, /* col 1: X[:,1] at rows 0-1 */
+        0.0, 0.0, 1.0, 2.0, 0.0, 0.0, /* col 2: X[:,0] at rows 2-3 */
+        0.0, 0.0, 3.0, 4.0, 0.0, 0.0, /* col 3: X[:,1] at rows 2-3 */
+        0.0, 0.0, 0.0, 0.0, 1.0, 2.0, /* col 4: X[:,0] at rows 4-5 */
+        0.0, 0.0, 0.0, 0.0, 3.0, 4.0, /* col 5: X[:,1] at rows 4-5 */
+    };
+
+    mu_assert("kron_left identity d1 != 6", Z->d1 == 6);
+    mu_assert("kron_left identity d2 != 6", Z->d2 == 6);
+    mu_assert("kron_left identity values", cmp_double_array(Z->value, expected, 36));
+
+    free_csr_matrix(I3);
+    free_expr(Z);
+    return 0;
+}
diff --git a/tests/jacobian_tests/affine/test_kron_left.h b/tests/jacobian_tests/affine/test_kron_left.h
new file mode 100644
index 0000000..07c9ba4
--- /dev/null
+++ b/tests/jacobian_tests/affine/test_kron_left.h
@@ -0,0 +1,99 @@
+#include /* NOTE(review): header name missing in this patch text; memcpy below needs <string.h> */
+#include /* NOTE(review): header name missing in this patch text - confirm */
+#include /* NOTE(review): header name missing in this patch text - confirm */
+
+#include "atoms/affine.h"
+#include "atoms/elementwise_restricted_dom.h"
+#include "expr.h"
+#include "minunit.h"
+#include "test_helpers.h"
+
+const char *test_jacobian_kron_left_log(void)
+{
+    /* Jacobian of kron(C, log(x)) where
+     *   x is 2x1 variable at x = [1, 2]
+     *   C is 2x2 sparse: [[1, 2], [0, 3]]
+     * Output kron(C, log(x)) is 4x2, vectorized to 8x1.
+     *
+     * Each active output row r (C[i,j] != 0) is C[i,j] * d log(x_s)/dx_s
+     * at col s = l*p + k. Since d log(x_k)/dx_v = delta_{k,v}/x_k:
+     *   r=0: [1,   0];   r=1: [0, 1/2];
+     *   r=2: zero;       r=3: zero;
+     *   r=4: [2,   0];   r=5: [0, 1];
+     *   r=6: [3,   0];   r=7: [0, 1.5].
+     */
+    double x_vals[2] = {1.0, 2.0};
+    expr *x = new_variable(2, 1, 0, 2);
+
+    CSR_Matrix *C = new_csr_matrix(2, 2, 3);
+    int C_p[3] = {0, 2, 3};
+    int C_i[3] = {0, 1, 1};
+    double C_x[3] = {1.0, 2.0, 3.0};
+    memcpy(C->p, C_p, 3 * sizeof(int));
+    memcpy(C->i, C_i, 3 * sizeof(int));
+    memcpy(C->x, C_x, 3 * sizeof(double));
+
+    expr *log_x = new_log(x);
+    expr *Z = new_kron_left(NULL, log_x, C, 2, 1);
+
+    Z->forward(Z, x_vals);
+    jacobian_init(Z);
+    Z->eval_jacobian(Z);
+
+    double expected_x[6] = {1.0, 0.5, 2.0, 1.0, 3.0, 1.5};
+    int expected_i[6] = {0, 1, 0, 1, 0, 1};
+    int expected_p[9] = {0, 1, 2, 2, 2, 3, 4, 5, 6}; /* rows 2-3 (C[1,0] == 0) are empty */
+
+    mu_assert("kron_left jac vals fail",
+              cmp_double_array(Z->jacobian->x, expected_x, 6));
+    mu_assert("kron_left jac cols fail",
+              cmp_int_array(Z->jacobian->i, expected_i, 6));
+    mu_assert("kron_left jac rows fail",
+              cmp_int_array(Z->jacobian->p, expected_p, 9));
+
+    free_csr_matrix(C);
+    free_expr(Z);
+    return 0;
+}
+
+const char *test_jacobian_kron_left_log_matrix(void)
+{
+    /* Jacobian of kron(C, log(x)) where
+     *   x is 2x2 variable, col-major [1,2,3,4]
+     *   C is 2x1 sparse: [[1], [2]]
+     * Output is 4x2, vectorized to 8x1. Every output row is active.
+     * J[r,var] = C[i,j] / x[s] at col s, zero elsewhere.
+     */
+    double x_vals[4] = {1.0, 2.0, 3.0, 4.0};
+    expr *x = new_variable(2, 2, 0, 4);
+
+    CSR_Matrix *C = new_csr_matrix(2, 1, 2);
+    int C_p[3] = {0, 1, 2};
+    int C_i[2] = {0, 0};
+    double C_x[2] = {1.0, 2.0};
+    memcpy(C->p, C_p, 3 * sizeof(int));
+    memcpy(C->i, C_i, 2 * sizeof(int));
+    memcpy(C->x, C_x, 2 * sizeof(double));
+
+    expr *log_x = new_log(x);
+    expr *Z = new_kron_left(NULL, log_x, C, 2, 2);
+
+    Z->forward(Z, x_vals);
+    jacobian_init(Z);
+    Z->eval_jacobian(Z);
+
+    double expected_x[8] = {1.0, 0.5, 2.0, 1.0, 1.0 / 3.0, 0.25, 2.0 / 3.0, 0.5};
+    int expected_i[8] = {0, 1, 0, 1, 2, 3, 2, 3};
+    int expected_p[9] = {0, 1, 2, 3, 4, 5, 6, 7, 8}; /* one entry per row */
+
+    mu_assert("kron_left matrix jac vals fail",
+              cmp_double_array(Z->jacobian->x, expected_x, 8));
+    mu_assert("kron_left matrix jac cols fail",
+              cmp_int_array(Z->jacobian->i, expected_i, 8));
+    mu_assert("kron_left matrix jac rows fail",
+              cmp_int_array(Z->jacobian->p, expected_p, 9));
+
+    free_csr_matrix(C);
+    free_expr(Z);
+    return 0;
+}
diff --git a/tests/wsum_hess/affine/test_kron_left.h b/tests/wsum_hess/affine/test_kron_left.h
new file mode 100644
index 0000000..55f3a17
--- /dev/null
+++ b/tests/wsum_hess/affine/test_kron_left.h
@@ -0,0 +1,91 @@
+#include /* NOTE(review): header name missing in this patch text; memcpy below needs <string.h> */
+#include /* NOTE(review): header name missing in this patch text - confirm */
+#include /* NOTE(review): header name missing in this patch text - confirm */
+#include /* NOTE(review): header name missing in this patch text - confirm */
+
+#include "atoms/affine.h"
+#include "atoms/elementwise_full_dom.h"
+#include "atoms/elementwise_restricted_dom.h"
+#include "expr.h"
+#include "minunit.h"
+#include "numerical_diff.h"
+#include "test_helpers.h"
+
+const char *test_wsum_hess_kron_left(void)
+{
+    /* wsum_hess of kron(C, log(x)) where
+     *   x is 2x1 variable at x = [1, 2]
+     *   C is 2x2 sparse: [[1, 2], [0, 3]]
+     * Weights w = [1, 2, 3, 4, 5, 6, 7, 8].
+     *
+     * w_child[s] = sum over active (i,j) of C[i,j] * w[(j*q+l)*mp + i*p + k]
+     * (p=2, q=1, mp=4):
+     *   w_child[0] = 1*w[0] + 2*w[4] + 3*w[6] = 1 + 10 + 21 = 32
+     *   w_child[1] = 1*w[1] + 2*w[5] + 3*w[7] = 2 + 12 + 24 = 38
+     * Hessian of log: H[k,k] = -1/x[k]^2 -> [-32, -9.5]. */
+    double x_vals[2] = {1.0, 2.0};
+    double w[8] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0};
+
+    expr *x = new_variable(2, 1, 0, 2);
+
+    CSR_Matrix *C = new_csr_matrix(2, 2, 3);
+    int C_p[3] = {0, 2, 3};
+    int C_i[3] = {0, 1, 1};
+    double C_x[3] = {1.0, 2.0, 3.0};
+    memcpy(C->p, C_p, 3 * sizeof(int));
+    memcpy(C->i, C_i, 3 * sizeof(int));
+    memcpy(C->x, C_x, 3 * sizeof(double));
+
+    expr *log_x = new_log(x);
+    expr *Z = new_kron_left(NULL, log_x, C, 2, 1);
+
+    Z->forward(Z, x_vals);
+    jacobian_init(Z);
+    wsum_hess_init(Z);
+    Z->eval_wsum_hess(Z, w);
+
+    double expected_x[2] = {-32.0, -9.5}; /* -w_child[k] / x[k]^2 */
+    int expected_i[2] = {0, 1};
+    int expected_p[3] = {0, 1, 2};
+
+    mu_assert("kron_left hess vals fail",
+              cmp_double_array(Z->wsum_hess->x, expected_x, 2));
+    mu_assert("kron_left hess cols fail",
+              cmp_int_array(Z->wsum_hess->i, expected_i, 2));
+    mu_assert("kron_left hess rows fail",
+              cmp_int_array(Z->wsum_hess->p, expected_p, 3));
+
+    free_csr_matrix(C);
+    free_expr(Z);
+    return 0;
+}
+
+const char *test_wsum_hess_kron_left_composite(void)
+{
+    /* Verify weight propagation through kron_left when the child has a
+     * non-trivial Hessian, by numerical differentiation against
+     * kron(C, exp(x)). exp is a full-domain elementwise atom so it
+     * composes on top of a variable child correctly here. */
+    double x_vals[2] = {0.3, -0.7};
+    double w[8] = {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5};
+
+    expr *x = new_variable(2, 1, 0, 2);
+
+    CSR_Matrix *C = new_csr_matrix(2, 2, 3);
+    int C_p[3] = {0, 2, 3};
+    int C_i[3] = {0, 1, 1};
+    double C_x[3] = {1.0, 2.0, 3.0};
+    memcpy(C->p, C_p, 3 * sizeof(int));
+    memcpy(C->i, C_i, 3 * sizeof(int));
+    memcpy(C->x, C_x, 3 * sizeof(double));
+
+    expr *exp_x = new_exp(x);
+    expr *Z = new_kron_left(NULL, exp_x, C, 2, 1);
+
+    mu_assert("kron_left composite wsum_hess check failed",
+              check_wsum_hess(Z, x_vals, w, NUMERICAL_DIFF_DEFAULT_H));
+
+    free_csr_matrix(C);
+    free_expr(Z);
+    return 0;
+}