44#include < nanobind/stl/optional.h>
55#include < nanobind/stl/string.h>
66
7+ #include < algorithm>
78#include < array>
89#include < iostream>
910#include < map>
@@ -17,16 +18,73 @@ namespace nb = nanobind;
1718using namespace nanobind ::literals;
1819
1920/*
20- The following class is necessary to ensure that the static variables used to store the callback
21- functions are automatically reset upon normal or abnormal exit from the `odr_wrapper` function.
22- From: https://github.com/libprima/prima/blob/main/python/_prima.cpp
21+ A container used to pass around the model function and its jacobians without creating a closure.
2322*/
24- class SelfCleaningPyObject {
25- nb::object &obj;
23+ struct Context {
24+ nb::callable fcn_f;
25+ nb::callable fcn_fjacb;
26+ nb::callable fcn_fjacd;
27+ };
28+
29+ /*
30+ Callback function invoked from `odrpack` to evaluate the model function and, optionally, its
31+ Jacobians. The actual python functions are stored in the `Context` struct, which is passed through
32+ the void pointer argument `data`. This is used to call the appropriate functions without
33+ creating closures.
34+ */
35+ void fcn (const int *n_ptr, const int *m_ptr, const int *q_ptr, const int *npar_ptr, const int *ldifx_ptr,
36+ const double beta[], const double xplusd[], const int ifixb[], const int ifixx[],
37+ const int *ideval_ptr, double f[], double fjacb[], double fjacd[], int *istop,
38+ void *data) {
39+ // Dereference/cast scalar inputs
40+ auto n = static_cast <size_t >(*n_ptr);
41+ auto m = static_cast <size_t >(*m_ptr);
42+ auto q = static_cast <size_t >(*q_ptr);
43+ auto npar = static_cast <size_t >(*npar_ptr);
44+ auto ideval = *ideval_ptr;
45+
46+ // Create NumPy arrays that wrap the input C-style arrays, without copying the data
47+ nb::ndarray<const double , nb::numpy> beta_ndarray (beta, {npar});
48+ nb::ndarray<const double , nb::numpy> xplusd_ndarray (
49+ xplusd,
50+ (m == 1 ) ? std::initializer_list<size_t >{n} : std::initializer_list<size_t >{m, n});
51+
52+ // Retrieve the model context
53+ const auto *context = static_cast <const Context *>(data);
54+
55+ *istop = 0 ;
56+ try {
57+ // Evaluate model function
58+ if (ideval % 10 > 0 ) {
59+ auto f_pyobject = context->fcn_f (xplusd_ndarray, beta_ndarray);
60+ auto f_ndarray = nb::cast<nb::ndarray<const double , nb::c_contig>>(f_pyobject);
61+ std::copy_n (f_ndarray.data (), q * n, f);
62+ }
63+
64+ // Model partial derivatives wrt `beta`
65+ if ((ideval / 10 ) % 10 != 0 ) {
66+ auto fjacb_pyobject = context->fcn_fjacb (xplusd_ndarray, beta_ndarray);
67+ auto fjacb_ndarray = nb::cast<nb::ndarray<const double , nb::c_contig>>(fjacb_pyobject);
68+ std::copy_n (fjacb_ndarray.data (), q * npar * n, fjacb);
69+ }
2670
27- public:
28- SelfCleaningPyObject (nb::object &obj) : obj(obj) {}
29- ~SelfCleaningPyObject () { obj = nb::none (); }
71+ // Model partial derivatives wrt `delta`
72+ if ((ideval / 100 ) % 10 != 0 ) {
73+ auto fjacd_pyobject = context->fcn_fjacd (xplusd_ndarray, beta_ndarray);
74+ auto fjacd_ndarray = nb::cast<nb::ndarray<const double , nb::c_contig>>(fjacd_pyobject);
75+ std::copy_n (fjacd_ndarray.data (), q * m * n, fjacd);
76+ }
77+
78+ } catch (const nb::python_error &e) {
79+ // temporary solution: need to figure out how to do this the right way
80+ std::string ewhat = e.what ();
81+ if (ewhat.find (" OdrStop" ) != std::string::npos) {
82+ std::cerr << ewhat << std::endl;
83+ *istop = 1 ;
84+ } else {
85+ throw ;
86+ }
87+ }
3088};
3189
3290/*
@@ -38,45 +96,46 @@ Some arguments have a default value of `nullptr` — this is by design, as the F
3896automatically interprets `nullptr` as an absent optional argument. This approach avoids the
3997redundant definition of default values in multiple places.
4098*/
41- int odr_wrapper (int n,
42- int m,
43- int q,
44- int npar,
45- int ldwe,
46- int ld2we,
47- int ldwd,
48- int ld2wd,
49- int ldifx,
50- int ldstpd,
51- int ldscld,
52- const nb::callable fcn_f,
53- const nb::callable fcn_fjacb,
54- const nb::callable fcn_fjacd,
55- nb::ndarray<double , nb::c_contig> beta,
56- nb::ndarray<const double , nb::c_contig> y,
57- nb::ndarray<const double , nb::c_contig> x,
58- nb::ndarray<double , nb::c_contig> delta,
59- std::optional<nb::ndarray<const double , nb::c_contig>> we,
60- std::optional<nb::ndarray<const double , nb::c_contig>> wd,
61- std::optional<nb::ndarray<const int , nb::c_contig>> ifixb,
62- std::optional<nb::ndarray<const int , nb::c_contig>> ifixx,
63- std::optional<nb::ndarray<const double , nb::c_contig>> stpb,
64- std::optional<nb::ndarray<const double , nb::c_contig>> stpd,
65- std::optional<nb::ndarray<const double , nb::c_contig>> sclb,
66- std::optional<nb::ndarray<const double , nb::c_contig>> scld,
67- std::optional<nb::ndarray<const double , nb::c_contig>> lower,
68- std::optional<nb::ndarray<const double , nb::c_contig>> upper,
69- std::optional<nb::ndarray<double , nb::c_contig>> rwork,
70- std::optional<nb::ndarray<int , nb::c_contig>> iwork,
71- std::optional<int > job,
72- std::optional<int > ndigit,
73- std::optional<double > taufac,
74- std::optional<double > sstol,
75- std::optional<double > partol,
76- std::optional<int > maxit,
77- std::optional<int > iprint,
78- std::optional<std::string> errfile,
79- std::optional<std::string> rptfile)
99+ int odr_wrapper (
100+ int n,
101+ int m,
102+ int q,
103+ int npar,
104+ int ldwe,
105+ int ld2we,
106+ int ldwd,
107+ int ld2wd,
108+ int ldifx,
109+ int ldstpd,
110+ int ldscld,
111+ const nb::callable fcn_f,
112+ const nb::callable fcn_fjacb,
113+ const nb::callable fcn_fjacd,
114+ nb::ndarray<double , nb::c_contig> beta,
115+ nb::ndarray<const double , nb::c_contig> y,
116+ nb::ndarray<const double , nb::c_contig> x,
117+ nb::ndarray<double , nb::c_contig> delta,
118+ std::optional<nb::ndarray<const double , nb::c_contig>> we,
119+ std::optional<nb::ndarray<const double , nb::c_contig>> wd,
120+ std::optional<nb::ndarray<const int , nb::c_contig>> ifixb,
121+ std::optional<nb::ndarray<const int , nb::c_contig>> ifixx,
122+ std::optional<nb::ndarray<const double , nb::c_contig>> stpb,
123+ std::optional<nb::ndarray<const double , nb::c_contig>> stpd,
124+ std::optional<nb::ndarray<const double , nb::c_contig>> sclb,
125+ std::optional<nb::ndarray<const double , nb::c_contig>> scld,
126+ std::optional<nb::ndarray<const double , nb::c_contig>> lower,
127+ std::optional<nb::ndarray<const double , nb::c_contig>> upper,
128+ std::optional<nb::ndarray<double , nb::c_contig>> rwork,
129+ std::optional<nb::ndarray<int , nb::c_contig>> iwork,
130+ std::optional<int > job,
131+ std::optional<int > ndigit,
132+ std::optional<double > taufac,
133+ std::optional<double > sstol,
134+ std::optional<double > partol,
135+ std::optional<int > maxit,
136+ std::optional<int > iprint,
137+ std::optional<std::string> errfile,
138+ std::optional<std::string> rptfile)
80139
81140{
82141 // Create pointers to the NumPy arrays and scalar arguments
@@ -116,87 +175,8 @@ int odr_wrapper(int n,
116175 if (rwork) lrwork = rwork.value ().size ();
117176 if (iwork) liwork = iwork.value ().size ();
118177
119- // Build static pointers to the Python functions
120- // The static variables are necessary to ensure that the Python functions can be accessed
121- // within the C-style function 'fcn'
122- static nb::callable fcn_f_holder;
123- fcn_f_holder = std::move (fcn_f);
124- auto cleaner_1 = SelfCleaningPyObject (fcn_f_holder);
125-
126- static nb::callable fcn_fjacb_holder;
127- fcn_fjacb_holder = std::move (fcn_fjacb);
128- auto cleaner_2 = SelfCleaningPyObject (fcn_fjacb_holder);
129-
130- static nb::callable fcn_fjacd_holder;
131- fcn_fjacd_holder = std::move (fcn_fjacd);
132- auto cleaner_3 = SelfCleaningPyObject (fcn_fjacd_holder);
133-
134- // Define the overall user-supplied model function 'fcn'.
135- // The model function and its Jacobians are still passed through the closure environment.
136- // In a future version, we should pass them via the thunk argument instead.
137- odrpack_fcn_t fcn = nullptr ;
138-
139- fcn = [](const int *n_ptr, const int *m_ptr, const int *q_ptr, const int *npar_ptr, const int *ldifx_ptr,
140- const double beta[], const double xplusd[], const int ifixb[], const int ifixx[],
141- const int *ideval_ptr, double f[], double fjacb[], double fjacd[], int *istop,
142- void *thunk) {
143- // Dereference scalar inputs
144- auto n = *n_ptr;
145- auto m = *m_ptr;
146- auto q = *q_ptr;
147- auto npar = *npar_ptr;
148- auto ideval = *ideval_ptr;
149-
150- // Create NumPy arrays that wrap the input C-style arrays, without copying the data
151- nb::ndarray<const double , nb::numpy> beta_ndarray (beta, {static_cast <size_t >(npar)});
152- nb::ndarray<const double , nb::numpy> xplusd_ndarray (
153- xplusd,
154- (m == 1 ) ? std::initializer_list<size_t >{static_cast <size_t >(n)}
155- : std::initializer_list<size_t >{static_cast <size_t >(m), static_cast <size_t >(n)});
156-
157- *istop = 0 ;
158- try {
159- // Evaluate model function
160- if (ideval % 10 > 0 ) {
161- auto f_object = fcn_f_holder (xplusd_ndarray, beta_ndarray);
162- auto f_ndarray = nb::cast<nb::ndarray<const double , nb::c_contig>>(f_object);
163- auto f_ndarray_ptr = f_ndarray.data ();
164- for (auto i = 0 ; i < q * n; i++) {
165- f[i] = f_ndarray_ptr[i];
166- }
167- }
168-
169- // Model partial derivatives wrt `beta`
170- if ((ideval / 10 ) % 10 != 0 ) {
171- auto fjacb_object = fcn_fjacb_holder (xplusd_ndarray, beta_ndarray);
172- auto fjacb_ndarray = nb::cast<nb::ndarray<const double , nb::c_contig>>(fjacb_object);
173- auto fjacb_ndarray_ptr = fjacb_ndarray.data ();
174- for (auto i = 0 ; i < q * npar * n; i++) {
175- fjacb[i] = fjacb_ndarray_ptr[i];
176- }
177- }
178-
179- // Model partial derivatives wrt `delta`
180- if ((ideval / 100 ) % 10 != 0 ) {
181- auto fjacd_object = fcn_fjacd_holder (xplusd_ndarray, beta_ndarray);
182- auto fjacd_ndarray = nb::cast<nb::ndarray<const double , nb::c_contig>>(fjacd_object);
183- auto fjacd_ndarray_ptr = fjacd_ndarray.data ();
184- for (auto i = 0 ; i < q * npar * n; i++) {
185- fjacd[i] = fjacd_ndarray_ptr[i];
186- }
187- }
188-
189- } catch (const nb::python_error &e) {
190- // temporary solution: need to figure out how to do this the right way
191- std::string ewhat = e.what ();
192- if (ewhat.find (" OdrStop" ) != std::string::npos) {
193- std::cerr << ewhat << std::endl;
194- *istop = 1 ;
195- } else {
196- throw ;
197- }
198- }
199- };
178+ // Define the context for the user-supplied model function and its Jacobians.
179+ Context context = {fcn_f, fcn_fjacb, fcn_fjacd};
200180
201181 // Open files
202182 int lunrpt = 6 ;
@@ -221,13 +201,14 @@ int odr_wrapper(int n,
221201
222202 // Call the C function
223203 int info = -1 ;
224- void *thunk = nullptr ;
225- odr_long_c (fcn, &n, &m, &q, &npar, &ldwe, &ld2we, &ldwd, &ld2wd, &ldifx,
226- &ldstpd, &ldscld, &lrwork, &liwork, beta_ptr, y_ptr, x_ptr, we_ptr,
227- wd_ptr, ifixb_ptr, ifixx_ptr, stpb_ptr, stpd_ptr, sclb_ptr,
228- scld_ptr, delta_ptr, lower_ptr, upper_ptr, rwork_ptr, iwork_ptr,
229- job_ptr, ndigit_ptr, taufac_ptr, sstol_ptr, partol_ptr, maxit_ptr,
230- iprint_ptr, &lunerr, &lunrpt, &info, thunk);
204+ odr_long_c (
205+ fcn, static_cast <void *>(&context),
206+ &n, &m, &q, &npar, &ldwe, &ld2we, &ldwd, &ld2wd, &ldifx,
207+ &ldstpd, &ldscld, &lrwork, &liwork, beta_ptr, y_ptr, x_ptr, we_ptr,
208+ wd_ptr, ifixb_ptr, ifixx_ptr, stpb_ptr, stpd_ptr, sclb_ptr,
209+ scld_ptr, delta_ptr, lower_ptr, upper_ptr, rwork_ptr, iwork_ptr,
210+ job_ptr, ndigit_ptr, taufac_ptr, sstol_ptr, partol_ptr, maxit_ptr,
211+ iprint_ptr, &lunerr, &lunrpt, &info);
231212
232213 // Close files
233214 if (rptfile) {
0 commit comments