Skip to content

Commit 53e1fd9

Browse files
committed
refactor callback w/o closures
1 parent 1aaca3f commit 53e1fd9

2 files changed

Lines changed: 117 additions & 136 deletions

File tree

src/odrpack/__odrpack.py.cpp

Lines changed: 116 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <nanobind/stl/optional.h>
55
#include <nanobind/stl/string.h>
66

7+
#include <algorithm>
78
#include <array>
89
#include <iostream>
910
#include <map>
@@ -17,16 +18,73 @@ namespace nb = nanobind;
1718
using namespace nanobind::literals;
1819

1920
/*
20-
The following class is necessary to ensure that the static variables used to store the callback
21-
functions are automatically reset upon normal or abnormal exit from the `odr_wrapper` function.
22-
From: https://github.com/libprima/prima/blob/main/python/_prima.cpp
21+
A container used to pass around the model function and its Jacobians without creating a closure.
2322
*/
24-
class SelfCleaningPyObject {
25-
nb::object &obj;
23+
struct Context {
24+
nb::callable fcn_f;
25+
nb::callable fcn_fjacb;
26+
nb::callable fcn_fjacd;
27+
};
28+
29+
/*
30+
Callback function invoked from `odrpack` to evaluate the model function and, optionally, its
31+
Jacobians. The actual Python functions are stored in the `Context` struct, which is passed through
32+
the void pointer argument `data`. This is used to call the appropriate functions without
33+
creating closures.
34+
*/
35+
void fcn(const int *n_ptr, const int *m_ptr, const int *q_ptr, const int *npar_ptr, const int *ldifx_ptr,
36+
const double beta[], const double xplusd[], const int ifixb[], const int ifixx[],
37+
const int *ideval_ptr, double f[], double fjacb[], double fjacd[], int *istop,
38+
void *data) {
39+
// Dereference/cast scalar inputs
40+
auto n = static_cast<size_t>(*n_ptr);
41+
auto m = static_cast<size_t>(*m_ptr);
42+
auto q = static_cast<size_t>(*q_ptr);
43+
auto npar = static_cast<size_t>(*npar_ptr);
44+
auto ideval = *ideval_ptr;
45+
46+
// Create NumPy arrays that wrap the input C-style arrays, without copying the data
47+
nb::ndarray<const double, nb::numpy> beta_ndarray(beta, {npar});
48+
nb::ndarray<const double, nb::numpy> xplusd_ndarray(
49+
xplusd,
50+
(m == 1) ? std::initializer_list<size_t>{n} : std::initializer_list<size_t>{m, n});
51+
52+
// Retrieve the model context
53+
const auto *context = static_cast<const Context *>(data);
54+
55+
*istop = 0;
56+
try {
57+
// Evaluate model function
58+
if (ideval % 10 > 0) {
59+
auto f_pyobject = context->fcn_f(xplusd_ndarray, beta_ndarray);
60+
auto f_ndarray = nb::cast<nb::ndarray<const double, nb::c_contig>>(f_pyobject);
61+
std::copy_n(f_ndarray.data(), q * n, f);
62+
}
63+
64+
// Model partial derivatives wrt `beta`
65+
if ((ideval / 10) % 10 != 0) {
66+
auto fjacb_pyobject = context->fcn_fjacb(xplusd_ndarray, beta_ndarray);
67+
auto fjacb_ndarray = nb::cast<nb::ndarray<const double, nb::c_contig>>(fjacb_pyobject);
68+
std::copy_n(fjacb_ndarray.data(), q * npar * n, fjacb);
69+
}
2670

27-
public:
28-
SelfCleaningPyObject(nb::object &obj) : obj(obj) {}
29-
~SelfCleaningPyObject() { obj = nb::none(); }
71+
// Model partial derivatives wrt `delta`
72+
if ((ideval / 100) % 10 != 0) {
73+
auto fjacd_pyobject = context->fcn_fjacd(xplusd_ndarray, beta_ndarray);
74+
auto fjacd_ndarray = nb::cast<nb::ndarray<const double, nb::c_contig>>(fjacd_pyobject);
75+
std::copy_n(fjacd_ndarray.data(), q * m * n, fjacd);
76+
}
77+
78+
} catch (const nb::python_error &e) {
79+
// temporary solution: need to figure out how to do this the right way
80+
std::string ewhat = e.what();
81+
if (ewhat.find("OdrStop") != std::string::npos) {
82+
std::cerr << ewhat << std::endl;
83+
*istop = 1;
84+
} else {
85+
throw;
86+
}
87+
}
3088
};
3189

3290
/*
@@ -38,45 +96,46 @@ Some arguments have a default value of `nullptr` — this is by design, as the F
3896
automatically interprets `nullptr` as an absent optional argument. This approach avoids the
3997
redundant definition of default values in multiple places.
4098
*/
41-
int odr_wrapper(int n,
42-
int m,
43-
int q,
44-
int npar,
45-
int ldwe,
46-
int ld2we,
47-
int ldwd,
48-
int ld2wd,
49-
int ldifx,
50-
int ldstpd,
51-
int ldscld,
52-
const nb::callable fcn_f,
53-
const nb::callable fcn_fjacb,
54-
const nb::callable fcn_fjacd,
55-
nb::ndarray<double, nb::c_contig> beta,
56-
nb::ndarray<const double, nb::c_contig> y,
57-
nb::ndarray<const double, nb::c_contig> x,
58-
nb::ndarray<double, nb::c_contig> delta,
59-
std::optional<nb::ndarray<const double, nb::c_contig>> we,
60-
std::optional<nb::ndarray<const double, nb::c_contig>> wd,
61-
std::optional<nb::ndarray<const int, nb::c_contig>> ifixb,
62-
std::optional<nb::ndarray<const int, nb::c_contig>> ifixx,
63-
std::optional<nb::ndarray<const double, nb::c_contig>> stpb,
64-
std::optional<nb::ndarray<const double, nb::c_contig>> stpd,
65-
std::optional<nb::ndarray<const double, nb::c_contig>> sclb,
66-
std::optional<nb::ndarray<const double, nb::c_contig>> scld,
67-
std::optional<nb::ndarray<const double, nb::c_contig>> lower,
68-
std::optional<nb::ndarray<const double, nb::c_contig>> upper,
69-
std::optional<nb::ndarray<double, nb::c_contig>> rwork,
70-
std::optional<nb::ndarray<int, nb::c_contig>> iwork,
71-
std::optional<int> job,
72-
std::optional<int> ndigit,
73-
std::optional<double> taufac,
74-
std::optional<double> sstol,
75-
std::optional<double> partol,
76-
std::optional<int> maxit,
77-
std::optional<int> iprint,
78-
std::optional<std::string> errfile,
79-
std::optional<std::string> rptfile)
99+
int odr_wrapper(
100+
int n,
101+
int m,
102+
int q,
103+
int npar,
104+
int ldwe,
105+
int ld2we,
106+
int ldwd,
107+
int ld2wd,
108+
int ldifx,
109+
int ldstpd,
110+
int ldscld,
111+
const nb::callable fcn_f,
112+
const nb::callable fcn_fjacb,
113+
const nb::callable fcn_fjacd,
114+
nb::ndarray<double, nb::c_contig> beta,
115+
nb::ndarray<const double, nb::c_contig> y,
116+
nb::ndarray<const double, nb::c_contig> x,
117+
nb::ndarray<double, nb::c_contig> delta,
118+
std::optional<nb::ndarray<const double, nb::c_contig>> we,
119+
std::optional<nb::ndarray<const double, nb::c_contig>> wd,
120+
std::optional<nb::ndarray<const int, nb::c_contig>> ifixb,
121+
std::optional<nb::ndarray<const int, nb::c_contig>> ifixx,
122+
std::optional<nb::ndarray<const double, nb::c_contig>> stpb,
123+
std::optional<nb::ndarray<const double, nb::c_contig>> stpd,
124+
std::optional<nb::ndarray<const double, nb::c_contig>> sclb,
125+
std::optional<nb::ndarray<const double, nb::c_contig>> scld,
126+
std::optional<nb::ndarray<const double, nb::c_contig>> lower,
127+
std::optional<nb::ndarray<const double, nb::c_contig>> upper,
128+
std::optional<nb::ndarray<double, nb::c_contig>> rwork,
129+
std::optional<nb::ndarray<int, nb::c_contig>> iwork,
130+
std::optional<int> job,
131+
std::optional<int> ndigit,
132+
std::optional<double> taufac,
133+
std::optional<double> sstol,
134+
std::optional<double> partol,
135+
std::optional<int> maxit,
136+
std::optional<int> iprint,
137+
std::optional<std::string> errfile,
138+
std::optional<std::string> rptfile)
80139

81140
{
82141
// Create pointers to the NumPy arrays and scalar arguments
@@ -116,87 +175,8 @@ int odr_wrapper(int n,
116175
if (rwork) lrwork = rwork.value().size();
117176
if (iwork) liwork = iwork.value().size();
118177

119-
// Build static pointers to the Python functions
120-
// The static variables are necessary to ensure that the Python functions can be accessed
121-
// within the C-style function 'fcn'
122-
static nb::callable fcn_f_holder;
123-
fcn_f_holder = std::move(fcn_f);
124-
auto cleaner_1 = SelfCleaningPyObject(fcn_f_holder);
125-
126-
static nb::callable fcn_fjacb_holder;
127-
fcn_fjacb_holder = std::move(fcn_fjacb);
128-
auto cleaner_2 = SelfCleaningPyObject(fcn_fjacb_holder);
129-
130-
static nb::callable fcn_fjacd_holder;
131-
fcn_fjacd_holder = std::move(fcn_fjacd);
132-
auto cleaner_3 = SelfCleaningPyObject(fcn_fjacd_holder);
133-
134-
// Define the overall user-supplied model function 'fcn'.
135-
// The model function and its Jacobians are still passed through the closure environment.
136-
// In a future version, we should pass them via the thunk argument instead.
137-
odrpack_fcn_t fcn = nullptr;
138-
139-
fcn = [](const int *n_ptr, const int *m_ptr, const int *q_ptr, const int *npar_ptr, const int *ldifx_ptr,
140-
const double beta[], const double xplusd[], const int ifixb[], const int ifixx[],
141-
const int *ideval_ptr, double f[], double fjacb[], double fjacd[], int *istop,
142-
void *thunk) {
143-
// Dereference scalar inputs
144-
auto n = *n_ptr;
145-
auto m = *m_ptr;
146-
auto q = *q_ptr;
147-
auto npar = *npar_ptr;
148-
auto ideval = *ideval_ptr;
149-
150-
// Create NumPy arrays that wrap the input C-style arrays, without copying the data
151-
nb::ndarray<const double, nb::numpy> beta_ndarray(beta, {static_cast<size_t>(npar)});
152-
nb::ndarray<const double, nb::numpy> xplusd_ndarray(
153-
xplusd,
154-
(m == 1) ? std::initializer_list<size_t>{static_cast<size_t>(n)}
155-
: std::initializer_list<size_t>{static_cast<size_t>(m), static_cast<size_t>(n)});
156-
157-
*istop = 0;
158-
try {
159-
// Evaluate model function
160-
if (ideval % 10 > 0) {
161-
auto f_object = fcn_f_holder(xplusd_ndarray, beta_ndarray);
162-
auto f_ndarray = nb::cast<nb::ndarray<const double, nb::c_contig>>(f_object);
163-
auto f_ndarray_ptr = f_ndarray.data();
164-
for (auto i = 0; i < q * n; i++) {
165-
f[i] = f_ndarray_ptr[i];
166-
}
167-
}
168-
169-
// Model partial derivatives wrt `beta`
170-
if ((ideval / 10) % 10 != 0) {
171-
auto fjacb_object = fcn_fjacb_holder(xplusd_ndarray, beta_ndarray);
172-
auto fjacb_ndarray = nb::cast<nb::ndarray<const double, nb::c_contig>>(fjacb_object);
173-
auto fjacb_ndarray_ptr = fjacb_ndarray.data();
174-
for (auto i = 0; i < q * npar * n; i++) {
175-
fjacb[i] = fjacb_ndarray_ptr[i];
176-
}
177-
}
178-
179-
// Model partial derivatives wrt `delta`
180-
if ((ideval / 100) % 10 != 0) {
181-
auto fjacd_object = fcn_fjacd_holder(xplusd_ndarray, beta_ndarray);
182-
auto fjacd_ndarray = nb::cast<nb::ndarray<const double, nb::c_contig>>(fjacd_object);
183-
auto fjacd_ndarray_ptr = fjacd_ndarray.data();
184-
for (auto i = 0; i < q * npar * n; i++) {
185-
fjacd[i] = fjacd_ndarray_ptr[i];
186-
}
187-
}
188-
189-
} catch (const nb::python_error &e) {
190-
// temporary solution: need to figure out how to do this the right way
191-
std::string ewhat = e.what();
192-
if (ewhat.find("OdrStop") != std::string::npos) {
193-
std::cerr << ewhat << std::endl;
194-
*istop = 1;
195-
} else {
196-
throw;
197-
}
198-
}
199-
};
178+
// Define the context for the user-supplied model function and its Jacobians.
179+
Context context = {fcn_f, fcn_fjacb, fcn_fjacd};
200180

201181
// Open files
202182
int lunrpt = 6;
@@ -221,13 +201,14 @@ int odr_wrapper(int n,
221201

222202
// Call the C function
223203
int info = -1;
224-
void *thunk = nullptr;
225-
odr_long_c(fcn, &n, &m, &q, &npar, &ldwe, &ld2we, &ldwd, &ld2wd, &ldifx,
226-
&ldstpd, &ldscld, &lrwork, &liwork, beta_ptr, y_ptr, x_ptr, we_ptr,
227-
wd_ptr, ifixb_ptr, ifixx_ptr, stpb_ptr, stpd_ptr, sclb_ptr,
228-
scld_ptr, delta_ptr, lower_ptr, upper_ptr, rwork_ptr, iwork_ptr,
229-
job_ptr, ndigit_ptr, taufac_ptr, sstol_ptr, partol_ptr, maxit_ptr,
230-
iprint_ptr, &lunerr, &lunrpt, &info, thunk);
204+
odr_long_c(
205+
fcn, static_cast<void *>(&context),
206+
&n, &m, &q, &npar, &ldwe, &ld2we, &ldwd, &ld2wd, &ldifx,
207+
&ldstpd, &ldscld, &lrwork, &liwork, beta_ptr, y_ptr, x_ptr, we_ptr,
208+
wd_ptr, ifixb_ptr, ifixx_ptr, stpb_ptr, stpd_ptr, sclb_ptr,
209+
scld_ptr, delta_ptr, lower_ptr, upper_ptr, rwork_ptr, iwork_ptr,
210+
job_ptr, ndigit_ptr, taufac_ptr, sstol_ptr, partol_ptr, maxit_ptr,
211+
iprint_ptr, &lunerr, &lunrpt, &info);
231212

232213
// Close files
233214
if (rptfile) {

subprojects/odrpack95.wrap

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[wrap-git]
22
url = https://github.com/HugoMVale/odrpack95.git
3-
revision = v2.1.0
3+
revision = v3.0.0
44
# revision = HEAD
55
depth = 1
66

0 commit comments

Comments
 (0)