Skip to content

Commit 2eab7fe

Browse files
committed
Fix __main__ namespace lookup in py_context:call
When using py_context:exec to define a function and then py_context:call with '__main__' as the module, the function was not found because '__main__' was looked up via PyImport_ImportModule (real Python module) instead of ctx->globals where exec stores defined functions. Add special handling to nif_context_call and owngil_execute_call to check ctx->globals first when module is '__main__', matching the existing behavior in nif_context_call_with_env.
1 parent 84325aa commit 2eab7fe

1 file changed

Lines changed: 53 additions & 25 deletions

File tree

c_src/py_nif.c

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2475,27 +2475,42 @@ static void owngil_execute_call(py_context_t *ctx) {
24752475
return;
24762476
}
24772477

2478-
/* Get or import module */
2479-
PyObject *module = context_get_module(ctx, module_name);
2480-
if (module == NULL) {
2481-
ctx->response_term = make_py_error(ctx->shared_env);
2482-
ctx->response_ok = false;
2483-
enif_free(module_name);
2484-
enif_free(func_name_str);
2485-
return;
2486-
}
2478+
PyObject *module = NULL;
2479+
PyObject *func = NULL;
24872480

2488-
/* Get function */
2489-
PyObject *func = PyObject_GetAttrString(module, func_name_str);
2490-
enif_free(module_name);
2491-
enif_free(func_name_str);
2481+
/* Special handling for __main__ module - check ctx->globals first */
2482+
if (strcmp(module_name, "__main__") == 0) {
2483+
func = PyDict_GetItemString(ctx->globals, func_name_str); /* Borrowed ref */
2484+
if (func != NULL) {
2485+
Py_INCREF(func);
2486+
}
2487+
}
24922488

24932489
if (func == NULL) {
2494-
ctx->response_term = make_py_error(ctx->shared_env);
2495-
ctx->response_ok = false;
2496-
return;
2490+
/* Get or import module */
2491+
module = context_get_module(ctx, module_name);
2492+
if (module == NULL) {
2493+
ctx->response_term = make_py_error(ctx->shared_env);
2494+
ctx->response_ok = false;
2495+
enif_free(module_name);
2496+
enif_free(func_name_str);
2497+
return;
2498+
}
2499+
2500+
/* Get function */
2501+
func = PyObject_GetAttrString(module, func_name_str);
2502+
if (func == NULL) {
2503+
ctx->response_term = make_py_error(ctx->shared_env);
2504+
ctx->response_ok = false;
2505+
enif_free(module_name);
2506+
enif_free(func_name_str);
2507+
return;
2508+
}
24972509
}
24982510

2511+
enif_free(module_name);
2512+
enif_free(func_name_str);
2513+
24992514
/* Convert args */
25002515
unsigned int args_len;
25012516
if (!enif_get_list_length(ctx->shared_env, args_term, &args_len)) {
@@ -4251,18 +4266,31 @@ static ERL_NIF_TERM nif_context_call(ErlNifEnv *env, int argc, const ERL_NIF_TER
42514266
bool prev_allow_suspension = tl_allow_suspension;
42524267
tl_allow_suspension = true;
42534268

4254-
/* Get or import module */
4255-
PyObject *module = context_get_module(ctx, module_name);
4256-
if (module == NULL) {
4257-
result = make_py_error(env);
4258-
goto cleanup;
4269+
PyObject *module = NULL;
4270+
PyObject *func = NULL;
4271+
4272+
/* Special handling for __main__ module - check ctx->globals first */
4273+
if (strcmp(module_name, "__main__") == 0) {
4274+
func = PyDict_GetItemString(ctx->globals, func_name); /* Borrowed ref */
4275+
if (func != NULL) {
4276+
Py_INCREF(func);
4277+
}
42594278
}
42604279

4261-
/* Get function */
4262-
PyObject *func = PyObject_GetAttrString(module, func_name);
42634280
if (func == NULL) {
4264-
result = make_py_error(env);
4265-
goto cleanup;
4281+
/* Get or import module */
4282+
module = context_get_module(ctx, module_name);
4283+
if (module == NULL) {
4284+
result = make_py_error(env);
4285+
goto cleanup;
4286+
}
4287+
4288+
/* Get function */
4289+
func = PyObject_GetAttrString(module, func_name);
4290+
if (func == NULL) {
4291+
result = make_py_error(env);
4292+
goto cleanup;
4293+
}
42664294
}
42674295

42684296
/* Convert args */

0 commit comments

Comments
 (0)