Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions mypyc/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ def build_using_shared_lib(
deps: list[str],
build_dir: str,
extra_compile_args: list[str],
extra_include_dirs: list[str],
) -> list[Extension]:
"""Produce the list of extension modules when a shared library is needed.

Expand All @@ -373,7 +374,7 @@ def build_using_shared_lib(
get_extension()(
shared_lib_name(group_name),
sources=cfiles,
include_dirs=[include_dir(), build_dir],
include_dirs=[include_dir(), build_dir] + extra_include_dirs,
depends=deps,
extra_compile_args=extra_compile_args,
)
Expand All @@ -399,7 +400,10 @@ def build_using_shared_lib(


def build_single_module(
sources: list[BuildSource], cfiles: list[str], extra_compile_args: list[str]
sources: list[BuildSource],
cfiles: list[str],
extra_compile_args: list[str],
extra_include_dirs: list[str],
) -> list[Extension]:
"""Produce the list of extension modules for a standalone extension.

Expand All @@ -409,7 +413,7 @@ def build_single_module(
get_extension()(
sources[0].module,
sources=cfiles,
include_dirs=[include_dir()],
include_dirs=[include_dir()] + extra_include_dirs,
extra_compile_args=extra_compile_args,
)
]
Expand Down Expand Up @@ -513,7 +517,7 @@ def mypyc_build(
*,
separate: bool | list[tuple[list[str], str | None]] = False,
only_compile_paths: Iterable[str] | None = None,
skip_cgen_input: tuple[list[list[tuple[str, str]]], list[str]] | None = None,
skip_cgen_input: tuple[list[list[tuple[str, str]]], list[tuple[str, list[str]]]] | None = None,
always_use_shared_lib: bool = False,
) -> tuple[emitmodule.Groups, list[tuple[list[str], list[str]]], list[SourceDep]]:
"""Do the front and middle end of mypyc building, producing and writing out C source."""
Expand Down Expand Up @@ -547,7 +551,7 @@ def mypyc_build(
write_file(os.path.join(compiler_options.target_dir, "ops.txt"), ops_text)
else:
group_cfiles = skip_cgen_input[0]
source_deps = [SourceDep(d) for d in skip_cgen_input[1]]
source_deps = [SourceDep(path, include_dirs=dirs) for (path, dirs) in skip_cgen_input[1]]

# Write out the generated C and collect the files for each group
# Should this be here??
Expand Down Expand Up @@ -664,7 +668,7 @@ def mypycify(
strip_asserts: bool = False,
multi_file: bool = False,
separate: bool | list[tuple[list[str], str | None]] = False,
skip_cgen_input: tuple[list[list[tuple[str, str]]], list[str]] | None = None,
skip_cgen_input: tuple[list[list[tuple[str, str]]], list[tuple[str, list[str]]]] | None = None,
target_dir: str | None = None,
include_runtime_files: bool | None = None,
strict_dunder_typing: bool = False,
Expand Down Expand Up @@ -781,12 +785,14 @@ def mypycify(
# runtime library in. Otherwise it just gets #included to save on
# compiler invocations.
shared_cfilenames = []
include_dirs = set()
if not compiler_options.include_runtime_files:
# Collect all files to copy: runtime files + conditional source files
files_to_copy = list(RUNTIME_C_FILES)
for source_dep in source_deps:
files_to_copy.append(source_dep.path)
files_to_copy.append(source_dep.get_header())
include_dirs.update(source_dep.include_dirs)

# Copy all files
for name in files_to_copy:
Expand All @@ -797,6 +803,7 @@ def mypycify(
shared_cfilenames.append(rt_file)

extensions = []
extra_include_dirs = [os.path.join(include_dir(), dir) for dir in include_dirs]
for (group_sources, lib_name), (cfilenames, deps) in zip(groups, group_cfilenames):
if lib_name:
extensions.extend(
Expand All @@ -807,11 +814,14 @@ def mypycify(
deps,
build_dir,
cflags,
extra_include_dirs,
)
)
else:
extensions.extend(
build_single_module(group_sources, cfilenames + shared_cfilenames, cflags)
build_single_module(
group_sources, cfilenames + shared_cfilenames, cflags, extra_include_dirs
)
)

if install_librt:
Expand Down
26 changes: 12 additions & 14 deletions mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,8 @@ def collect_source_dependencies(modules: dict[str, ModuleIR]) -> set[SourceDep]:
for dep in module.dependencies:
if isinstance(dep, SourceDep):
source_deps.add(dep)
else:
source_deps.add(dep.api_dep())
return source_deps


Expand Down Expand Up @@ -585,6 +587,8 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]:
source_deps = collect_source_dependencies(self.modules)
for source_dep in sorted(source_deps, key=lambda d: d.path):
base_emitter.emit_line(f'#include "{source_dep.path}"')
if self.compiler_options.depends_on_librt_internal:
base_emitter.emit_line('#include "internal/librt_internal_api.c"')
base_emitter.emit_line(f'#include "__native{self.short_group_suffix}.h"')
base_emitter.emit_line(f'#include "__native_internal{self.short_group_suffix}.h"')
emitter = base_emitter
Expand Down Expand Up @@ -634,26 +638,20 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]:
ext_declarations.emit_line(f"#define MYPYC_NATIVE{self.group_suffix}_H")
ext_declarations.emit_line("#include <Python.h>")
ext_declarations.emit_line("#include <CPy.h>")
if self.compiler_options.depends_on_librt_internal:
ext_declarations.emit_line("#include <internal/librt_internal.h>")
if any(LIBRT_BASE64 in mod.dependencies for mod in self.modules.values()):
ext_declarations.emit_line("#include <base64/librt_base64.h>")
if any(LIBRT_STRINGS in mod.dependencies for mod in self.modules.values()):
ext_declarations.emit_line("#include <strings/librt_strings.h>")
if any(LIBRT_TIME in mod.dependencies for mod in self.modules.values()):
ext_declarations.emit_line("#include <time/librt_time.h>")
if any(LIBRT_VECS in mod.dependencies for mod in self.modules.values()):
ext_declarations.emit_line("#include <vecs/librt_vecs.h>")
# Include headers for conditional source files
source_deps = collect_source_dependencies(self.modules)
for source_dep in sorted(source_deps, key=lambda d: d.path):
ext_declarations.emit_line(f'#include "{source_dep.get_header()}"')

declarations = Emitter(self.context)
declarations.emit_line(f"#ifndef MYPYC_LIBRT_INTERNAL{self.group_suffix}_H")
declarations.emit_line(f"#define MYPYC_LIBRT_INTERNAL{self.group_suffix}_H")
declarations.emit_line("#include <Python.h>")
declarations.emit_line("#include <CPy.h>")

if self.compiler_options.depends_on_librt_internal:
declarations.emit_line('#include "internal/librt_internal_api.h"')
# Include headers for conditional source files
source_deps = collect_source_dependencies(self.modules)
for source_dep in sorted(source_deps, key=lambda d: d.path):
declarations.emit_line(f'#include "{source_dep.get_header()}"')

declarations.emit_line(f'#include "__native{self.short_group_suffix}.h"')
declarations.emit_line()
declarations.emit_line("int CPyGlobalsInit(void);")
Expand Down
11 changes: 10 additions & 1 deletion mypyc/ir/deps.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Final


Expand All @@ -17,17 +19,24 @@ def __eq__(self, other: object) -> bool:
def __hash__(self) -> int:
return hash(("Capsule", self.name))

def api_dep(self) -> SourceDep:
module = self.name.split(".")[-1]
return SourceDep(f"{module}/librt_{module}_api.c", include_dirs=[module])


class SourceDep:
"""Defines a C source file that a primitive may require.

Each source file must also have a corresponding .h file (replace .c with .h)
that gets implicitly #included if the source is used.
include_dirs are passed to the C compiler when the file is compiled as a
shared library separate from the C extension.
"""

def __init__(self, path: str) -> None:
def __init__(self, path: str, *, include_dirs: list[str] | None = None) -> None:
# Relative path from mypyc/lib-rt, e.g. 'bytes_extra_ops.c'
self.path: Final = path
self.include_dirs: Final = include_dirs or []

def __repr__(self) -> str:
return f"SourceDep(path={self.path!r})"
Expand Down
40 changes: 0 additions & 40 deletions mypyc/lib-rt/base64/librt_base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,44 +7,4 @@
#define LIBRT_BASE64_API_VERSION 2
#define LIBRT_BASE64_API_LEN 4

static void *LibRTBase64_API[LIBRT_BASE64_API_LEN];

#define LibRTBase64_ABIVersion (*(int (*)(void)) LibRTBase64_API[0])
#define LibRTBase64_APIVersion (*(int (*)(void)) LibRTBase64_API[1])
#define LibRTBase64_b64encode_internal (*(PyObject* (*)(PyObject *source, bool urlsafe)) LibRTBase64_API[2])
#define LibRTBase64_b64decode_internal (*(PyObject* (*)(PyObject *source, bool urlsafe)) LibRTBase64_API[3])

static int
import_librt_base64(void)
{
PyObject *mod = PyImport_ImportModule("librt.base64");
if (mod == NULL)
return -1;
Py_DECREF(mod); // we import just for the side effect of making the below work.
void *capsule = PyCapsule_Import("librt.base64._C_API", 0);
if (capsule == NULL)
return -1;
memcpy(LibRTBase64_API, capsule, sizeof(LibRTBase64_API));
if (LibRTBase64_ABIVersion() != LIBRT_BASE64_ABI_VERSION) {
char err[128];
snprintf(err, sizeof(err), "ABI version conflict for librt.base64, expected %d, found %d",
LIBRT_BASE64_ABI_VERSION,
LibRTBase64_ABIVersion()
);
PyErr_SetString(PyExc_ValueError, err);
return -1;
}
if (LibRTBase64_APIVersion() < LIBRT_BASE64_API_VERSION) {
char err[128];
snprintf(err, sizeof(err),
"API version conflict for librt.base64, expected %d or newer, found %d (hint: upgrade librt)",
LIBRT_BASE64_API_VERSION,
LibRTBase64_APIVersion()
);
PyErr_SetString(PyExc_ValueError, err);
return -1;
}
return 0;
}

#endif // LIBRT_BASE64_H
36 changes: 36 additions & 0 deletions mypyc/lib-rt/base64/librt_base64_api.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include "librt_base64_api.h"

void *LibRTBase64_API[LIBRT_BASE64_API_LEN] = {0};

int
import_librt_base64(void)
{
PyObject *mod = PyImport_ImportModule("librt.base64");
if (mod == NULL)
return -1;
Py_DECREF(mod); // we import just for the side effect of making the below work.
void *capsule = PyCapsule_Import("librt.base64._C_API", 0);
if (capsule == NULL)
return -1;
memcpy(LibRTBase64_API, capsule, sizeof(LibRTBase64_API));
if (LibRTBase64_ABIVersion() != LIBRT_BASE64_ABI_VERSION) {
char err[128];
snprintf(err, sizeof(err), "ABI version conflict for librt.base64, expected %d, found %d",
LIBRT_BASE64_ABI_VERSION,
LibRTBase64_ABIVersion()
);
PyErr_SetString(PyExc_ValueError, err);
return -1;
}
if (LibRTBase64_APIVersion() < LIBRT_BASE64_API_VERSION) {
char err[128];
snprintf(err, sizeof(err),
"API version conflict for librt.base64, expected %d or newer, found %d (hint: upgrade librt)",
LIBRT_BASE64_API_VERSION,
LibRTBase64_APIVersion()
);
PyErr_SetString(PyExc_ValueError, err);
return -1;
}
return 0;
}
15 changes: 15 additions & 0 deletions mypyc/lib-rt/base64/librt_base64_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef LIBRT_BASE64_API_H
#define LIBRT_BASE64_API_H

#include "librt_base64.h"

extern void *LibRTBase64_API[LIBRT_BASE64_API_LEN];

#define LibRTBase64_ABIVersion (*(int (*)(void)) LibRTBase64_API[0])
#define LibRTBase64_APIVersion (*(int (*)(void)) LibRTBase64_API[1])
#define LibRTBase64_b64encode_internal (*(PyObject* (*)(PyObject *source, bool urlsafe)) LibRTBase64_API[2])
#define LibRTBase64_b64decode_internal (*(PyObject* (*)(PyObject *source, bool urlsafe)) LibRTBase64_API[3])

int import_librt_base64(void);

#endif // LIBRT_BASE64_API_H
2 changes: 1 addition & 1 deletion mypyc/lib-rt/byteswriter_extra_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include <Python.h>

#include "mypyc_util.h"
#include "strings/librt_strings.h"
#include "strings/librt_strings_api.h"
#include "strings/librt_strings_common.h"

// BytesWriter: Length and capacity
Expand Down
70 changes: 2 additions & 68 deletions mypyc/lib-rt/internal/librt_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define LIBRT_INTERNAL_H

#include <Python.h>
#include <stdbool.h>

// ABI version -- only an exact match is compatible. This will only be changed in
// very exceptional cases (likely never) due to strict backward compatibility
Expand All @@ -11,7 +12,7 @@
// API version -- more recent versions must maintain backward compatibility, i.e.
// we can add new features but not remove or change existing features (unless
// ABI version is changed, but see the comment above).
#define LIBRT_INTERNAL_API_VERSION 1
#define LIBRT_INTERNAL_API_VERSION 1

// Number of functions in the capsule API. If you add a new function, also increase
// LIBRT_INTERNAL_API_VERSION.
Expand Down Expand Up @@ -43,73 +44,6 @@ static PyTypeObject *WriteBuffer_type_internal(void);
static int NativeInternal_API_Version(void);
static PyObject *extract_symbol_internal(PyObject *data);

#else

static void *NativeInternal_API[LIBRT_INTERNAL_API_LEN];

#define ReadBuffer_internal (*(PyObject* (*)(PyObject *source)) NativeInternal_API[0])
#define WriteBuffer_internal (*(PyObject* (*)(void)) NativeInternal_API[1])
#define WriteBuffer_getvalue_internal (*(PyObject* (*)(PyObject *source)) NativeInternal_API[2])
#define write_bool_internal (*(char (*)(PyObject *source, char value)) NativeInternal_API[3])
#define read_bool_internal (*(char (*)(PyObject *source)) NativeInternal_API[4])
#define write_str_internal (*(char (*)(PyObject *source, PyObject *value)) NativeInternal_API[5])
#define read_str_internal (*(PyObject* (*)(PyObject *source)) NativeInternal_API[6])
#define write_float_internal (*(char (*)(PyObject *source, double value)) NativeInternal_API[7])
#define read_float_internal (*(double (*)(PyObject *source)) NativeInternal_API[8])
#define write_int_internal (*(char (*)(PyObject *source, CPyTagged value)) NativeInternal_API[9])
#define read_int_internal (*(CPyTagged (*)(PyObject *source)) NativeInternal_API[10])
#define write_tag_internal (*(char (*)(PyObject *source, uint8_t value)) NativeInternal_API[11])
#define read_tag_internal (*(uint8_t (*)(PyObject *source)) NativeInternal_API[12])
#define NativeInternal_ABI_Version (*(int (*)(void)) NativeInternal_API[13])
#define write_bytes_internal (*(char (*)(PyObject *source, PyObject *value)) NativeInternal_API[14])
#define read_bytes_internal (*(PyObject* (*)(PyObject *source)) NativeInternal_API[15])
#define cache_version_internal (*(uint8_t (*)(void)) NativeInternal_API[16])
#define ReadBuffer_type_internal (*(PyTypeObject* (*)(void)) NativeInternal_API[17])
#define WriteBuffer_type_internal (*(PyTypeObject* (*)(void)) NativeInternal_API[18])
#define NativeInternal_API_Version (*(int (*)(void)) NativeInternal_API[19])
#define extract_symbol_internal (*(PyObject* (*)(PyObject *source)) NativeInternal_API[20])

static int
import_librt_internal(void)
{
PyObject *mod = PyImport_ImportModule("librt.internal");
if (mod == NULL)
return -1;
Py_DECREF(mod); // we import just for the side effect of making the below work.
void *capsule = PyCapsule_Import("librt.internal._C_API", 0);
if (capsule == NULL)
return -1;
memcpy(NativeInternal_API, capsule, sizeof(NativeInternal_API));
if (NativeInternal_ABI_Version() != LIBRT_INTERNAL_ABI_VERSION) {
char err[128];
snprintf(err, sizeof(err), "ABI version conflict for librt.internal, expected %d, found %d",
LIBRT_INTERNAL_ABI_VERSION,
NativeInternal_ABI_Version()
);
PyErr_SetString(PyExc_ValueError, err);
return -1;
}
if (NativeInternal_API_Version() < LIBRT_INTERNAL_API_VERSION) {
char err[128];
snprintf(err, sizeof(err),
"API version conflict for librt.internal, expected %d or newer, found %d (hint: upgrade librt)",
LIBRT_INTERNAL_API_VERSION,
NativeInternal_API_Version()
);
PyErr_SetString(PyExc_ValueError, err);
return -1;
}
return 0;
}

#endif

static inline bool CPyReadBuffer_Check(PyObject *obj) {
return Py_TYPE(obj) == ReadBuffer_type_internal();
}

static inline bool CPyWriteBuffer_Check(PyObject *obj) {
return Py_TYPE(obj) == WriteBuffer_type_internal();
}

#endif // LIBRT_INTERNAL_H
Loading
Loading