6 changes: 4 additions & 2 deletions scripts/cache_data.json
@@ -20,7 +20,8 @@
"pyarrow.decimal32",
"pyarrow.decimal64",
"pyarrow.decimal128"
]
],
"required": false
},
"pyarrow.dataset": {
"type": "module",
Expand All @@ -29,7 +30,8 @@
"children": [
"pyarrow.dataset.Scanner",
"pyarrow.dataset.Dataset"
]
],
"required": false
},
"pyarrow.dataset.Scanner": {
"type": "attribute",
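The new "required": false entries correspond to the IsRequired() overrides added to the import-cache header below: a missing pyarrow or pyarrow.dataset no longer aborts scan setup, and the C++ code simply checks whether the cached module is present. A rough Python equivalent of that optional-import pattern, for illustration only:

# Illustrative only; the real check happens in C++ through the import cache
try:
    import pyarrow.dataset as ds  # optional: enables projection/filter pushdown
except ImportError:
    ds = None  # DuckDB falls back to a plain scan without pushdown (arrow_scan_dumb)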
108 changes: 80 additions & 28 deletions src/duckdb_py/arrow/arrow_array_stream.cpp
@@ -66,6 +66,45 @@ unique_ptr<ArrowArrayStreamWrapper> PythonTableArrowArrayStreamFactory::Produce(
py::handle arrow_obj_handle(factory->arrow_object);
auto arrow_object_type = DuckDBPyConnection::GetArrowType(arrow_obj_handle);

if (arrow_object_type == PyArrowObjectType::PyCapsuleInterface) {
py::object capsule_obj = arrow_obj_handle.attr("__arrow_c_stream__")();
auto capsule = py::reinterpret_borrow<py::capsule>(capsule_obj);
auto stream = capsule.get_pointer<struct ArrowArrayStream>();
if (!stream->release) {
throw InvalidInputException(
"The __arrow_c_stream__() method returned a released stream. "
"If this object is single-use, implement __arrow_c_schema__() or expose a .schema attribute "
"with _export_to_c() so that DuckDB can extract the schema without consuming the stream.");
}

auto &import_cache_check = *DuckDBPyConnection::ImportCache();
if (import_cache_check.pyarrow.dataset()) {
// Tier A: full pushdown via pyarrow.dataset
// Import as RecordBatchReader, feed through Scanner.from_batches for projection/filter pushdown.
auto pyarrow_lib_module = py::module::import("pyarrow").attr("lib");
auto import_func = pyarrow_lib_module.attr("RecordBatchReader").attr("_import_from_c");
py::object reader = import_func(reinterpret_cast<uint64_t>(stream));
// _import_from_c takes ownership of the stream; null out to prevent capsule double-free
stream->release = nullptr;
auto &import_cache = *DuckDBPyConnection::ImportCache();
py::object arrow_batch_scanner = import_cache.pyarrow.dataset.Scanner().attr("from_batches");
py::handle reader_handle = reader;
auto scanner = ProduceScanner(arrow_batch_scanner, reader_handle, parameters, factory->client_properties);
auto record_batches = scanner.attr("to_reader")();
auto res = make_uniq<ArrowArrayStreamWrapper>();
auto export_to_c = record_batches.attr("_export_to_c");
export_to_c(reinterpret_cast<uint64_t>(&res->arrow_array_stream));
return res;
} else {
// Tier B: no pyarrow.dataset, return raw stream (no pushdown)
// DuckDB applies projection/filter post-scan via arrow_scan_dumb
auto res = make_uniq<ArrowArrayStreamWrapper>();
res->arrow_array_stream = *stream;
stream->release = nullptr;
return res;
}
}

if (arrow_object_type == PyArrowObjectType::PyCapsule) {
auto res = make_uniq<ArrowArrayStreamWrapper>();
auto capsule = py::reinterpret_borrow<py::capsule>(arrow_obj_handle);
@@ -78,21 +117,12 @@ unique_ptr<ArrowArrayStreamWrapper> PythonTableArrowArrayStreamFactory::Produce(
return res;
}

// Scanner and Dataset: require pyarrow.dataset for pushdown
VerifyArrowDatasetLoaded();
auto &import_cache = *DuckDBPyConnection::ImportCache();
py::object scanner;
py::object arrow_batch_scanner = import_cache.pyarrow.dataset.Scanner().attr("from_batches");
switch (arrow_object_type) {
case PyArrowObjectType::Table: {
auto arrow_dataset = import_cache.pyarrow.dataset().attr("dataset");
auto dataset = arrow_dataset(arrow_obj_handle);
py::object arrow_scanner = dataset.attr("__class__").attr("scanner");
scanner = ProduceScanner(arrow_scanner, dataset, parameters, factory->client_properties);
break;
}
case PyArrowObjectType::RecordBatchReader: {
scanner = ProduceScanner(arrow_batch_scanner, arrow_obj_handle, parameters, factory->client_properties);
break;
}
case PyArrowObjectType::Scanner: {
// If it's a scanner we have to turn it into a record batch reader, and then a scanner again, since we can't stack
// scanners on Arrow. Otherwise pushed-down projections and filters will disappear like tears in the rain.
@@ -119,37 +149,29 @@ unique_ptr<ArrowArrayStreamWrapper> PythonTableArrowArrayStreamFactory::Produce(
}
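For reference, a minimal sketch of what the new PyCapsuleInterface branch enables from the Python side (the class name and data are illustrative, not part of this change; assumes pyarrow >= 14 for the dunder exports): any object exposing __arrow_c_stream__ becomes scannable, and the availability of pyarrow.dataset only decides between Tier A (pushdown through Scanner.from_batches) and Tier B (raw stream, filtering after the scan).

import duckdb
import pyarrow as pa

class StreamOnly:
    # Produces a fresh Arrow C stream on every call (illustrative producer)
    def __init__(self, table: pa.Table):
        self.table = table

    def __arrow_c_stream__(self, requested_schema=None):
        return self.table.__arrow_c_stream__(requested_schema)

obj = StreamOnly(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}))
# Tier A (pyarrow.dataset installed): projection/filter are pushed into a pyarrow Scanner
# Tier B (pyarrow.dataset missing): arrow_scan_dumb reads the whole stream, DuckDB filters afterwards
print(duckdb.sql("select a from obj where b > 4").fetchall())  # [(2,), (3,)]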

void PythonTableArrowArrayStreamFactory::GetSchemaInternal(py::handle arrow_obj_handle, ArrowSchemaWrapper &schema) {
// PyCapsule (from bare capsule Produce path)
if (py::isinstance<py::capsule>(arrow_obj_handle)) {
auto capsule = py::reinterpret_borrow<py::capsule>(arrow_obj_handle);
auto stream = capsule.get_pointer<struct ArrowArrayStream>();
if (!stream->release) {
throw InvalidInputException("This ArrowArrayStream has already been consumed and cannot be scanned again.");
}
stream->get_schema(stream, &schema.arrow_schema);
return;
}

auto table_class = py::module::import("pyarrow").attr("Table");
if (py::isinstance(arrow_obj_handle, table_class)) {
auto obj_schema = arrow_obj_handle.attr("schema");
auto export_to_c = obj_schema.attr("_export_to_c");
export_to_c(reinterpret_cast<uint64_t>(&schema.arrow_schema));
if (stream->get_schema(stream, &schema.arrow_schema)) {
throw InvalidInputException("Failed to get Arrow schema from stream: %s",
stream->get_last_error ? stream->get_last_error(stream) : "unknown error");
}
return;
}

// Scanner: use projected_schema; everything else (RecordBatchReader, Dataset): use .schema
VerifyArrowDatasetLoaded();

auto &import_cache = *DuckDBPyConnection::ImportCache();
auto scanner_class = import_cache.pyarrow.dataset.Scanner();

if (py::isinstance(arrow_obj_handle, scanner_class)) {
if (py::isinstance(arrow_obj_handle, import_cache.pyarrow.dataset.Scanner())) {
auto obj_schema = arrow_obj_handle.attr("projected_schema");
auto export_to_c = obj_schema.attr("_export_to_c");
export_to_c(reinterpret_cast<uint64_t>(&schema));
obj_schema.attr("_export_to_c")(reinterpret_cast<uint64_t>(&schema.arrow_schema));
} else {
auto obj_schema = arrow_obj_handle.attr("schema");
auto export_to_c = obj_schema.attr("_export_to_c");
export_to_c(reinterpret_cast<uint64_t>(&schema));
obj_schema.attr("_export_to_c")(reinterpret_cast<uint64_t>(&schema.arrow_schema));
}
}

@@ -158,6 +180,36 @@ void PythonTableArrowArrayStreamFactory::GetSchema(uintptr_t factory_ptr, ArrowS
auto factory = static_cast<PythonTableArrowArrayStreamFactory *>(reinterpret_cast<void *>(factory_ptr)); // NOLINT
D_ASSERT(factory->arrow_object);
py::handle arrow_obj_handle(factory->arrow_object);

auto type = DuckDBPyConnection::GetArrowType(arrow_obj_handle);
if (type == PyArrowObjectType::PyCapsuleInterface) {
// Get __arrow_c_schema__ if it exists
if (py::hasattr(arrow_obj_handle, "__arrow_c_schema__")) {
auto schema_capsule = arrow_obj_handle.attr("__arrow_c_schema__")();
auto capsule = py::reinterpret_borrow<py::capsule>(schema_capsule);
auto arrow_schema = capsule.get_pointer<struct ArrowSchema>();
schema.arrow_schema = *arrow_schema;
arrow_schema->release = nullptr; // mark the source as moved so the capsule destructor does not release it
return;
}
// Otherwise try to use .schema with _export_to_c
if (py::hasattr(arrow_obj_handle, "schema")) {
auto obj_schema = arrow_obj_handle.attr("schema");
if (py::hasattr(obj_schema, "_export_to_c")) {
obj_schema.attr("_export_to_c")(reinterpret_cast<uint64_t>(&schema.arrow_schema));
return;
}
}
// Fallback: create a temporary stream just for the schema (consumes single-use streams!)
auto stream_capsule = arrow_obj_handle.attr("__arrow_c_stream__")();
auto capsule = py::reinterpret_borrow<py::capsule>(stream_capsule);
auto stream = capsule.get_pointer<struct ArrowArrayStream>();
if (stream->get_schema(stream, &schema.arrow_schema)) {
throw InvalidInputException("Failed to get Arrow schema from stream: %s",
stream->get_last_error ? stream->get_last_error(stream) : "unknown error");
}
return; // stream_capsule goes out of scope, stream released by capsule destructor
}
GetSchemaInternal(arrow_obj_handle, schema);
}
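The fallback above is the one path that must open, and thereby consume, a stream just to read its schema, so single-use producers are best served by also implementing __arrow_c_schema__. A minimal sketch of such a producer, assuming pyarrow >= 14 for the schema/stream dunder exports (names are illustrative):

import pyarrow as pa

class OneShotStream:
    # Wraps a RecordBatchReader that can only be consumed once (illustrative)
    def __init__(self, reader: pa.RecordBatchReader):
        self.reader = reader

    def __arrow_c_schema__(self):
        # Cheap: export only the schema, leaving the stream untouched
        return self.reader.schema.__arrow_c_schema__()

    def __arrow_c_stream__(self, requested_schema=None):
        return self.reader.__arrow_c_stream__(requested_schema)

With __arrow_c_schema__ present, GetSchema() never has to touch the stream, so schema binding no longer risks consuming the only stream the producer can provide.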

11 changes: 1 addition & 10 deletions src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp
@@ -51,16 +51,7 @@ class Table : public py::object {

} // namespace pyarrow

enum class PyArrowObjectType {
Invalid,
Table,
RecordBatchReader,
Scanner,
Dataset,
PyCapsule,
PyCapsuleInterface,
MessageReader
};
enum class PyArrowObjectType { Invalid, Table, Scanner, Dataset, PyCapsule, PyCapsuleInterface, MessageReader };

void TransformDuckToArrowChunk(ArrowSchema &arrow_schema, ArrowArray &data, py::list &batches);

@@ -46,6 +46,11 @@ struct PyarrowDatasetCacheItem : public PythonImportCacheItem {

PythonImportCacheItem Scanner;
PythonImportCacheItem Dataset;

protected:
bool IsRequired() const override final {
return false;
}
};

struct PyarrowCacheItem : public PythonImportCacheItem {
@@ -80,6 +85,11 @@ struct PyarrowCacheItem : public PythonImportCacheItem {
PythonImportCacheItem decimal32;
PythonImportCacheItem decimal64;
PythonImportCacheItem decimal128;

protected:
bool IsRequired() const override final {
return false;
}
};

} // namespace duckdb
20 changes: 5 additions & 15 deletions src/duckdb_py/pyconnection.cpp
@@ -2383,26 +2383,16 @@ PyArrowObjectType DuckDBPyConnection::GetArrowType(const py::handle &obj) {

if (ModuleIsLoaded<PyarrowCacheItem>()) {
auto &import_cache = *DuckDBPyConnection::ImportCache();
// First Verify Lib Types
auto table_class = import_cache.pyarrow.Table();
auto record_batch_reader_class = import_cache.pyarrow.RecordBatchReader();
auto message_reader_class = import_cache.pyarrow.ipc.MessageReader();
if (py::isinstance(obj, table_class)) {
return PyArrowObjectType::Table;
} else if (py::isinstance(obj, record_batch_reader_class)) {
return PyArrowObjectType::RecordBatchReader;
} else if (py::isinstance(obj, message_reader_class)) {
// MessageReader requires nanoarrow, separate scan function
if (py::isinstance(obj, import_cache.pyarrow.ipc.MessageReader())) {
return PyArrowObjectType::MessageReader;
}

if (ModuleIsLoaded<PyarrowDatasetCacheItem>()) {
// Then Verify dataset types
auto dataset_class = import_cache.pyarrow.dataset.Dataset();
auto scanner_class = import_cache.pyarrow.dataset.Scanner();

if (py::isinstance(obj, scanner_class)) {
// Scanner/Dataset don't have __arrow_c_stream__, need dedicated handling
if (py::isinstance(obj, import_cache.pyarrow.dataset.Scanner())) {
return PyArrowObjectType::Scanner;
} else if (py::isinstance(obj, dataset_class)) {
} else if (py::isinstance(obj, import_cache.pyarrow.dataset.Dataset())) {
return PyArrowObjectType::Dataset;
}
}
24 changes: 15 additions & 9 deletions src/duckdb_py/python_replacement_scan.cpp
@@ -51,11 +51,6 @@ static void CreateArrowScan(const string &name, py::object entry, TableFunctionR
auto dependency_item = PythonDependencyItem::Create(stream_messages);
external_dependency->AddDependency("replacement_cache", std::move(dependency_item));
} else {
if (type == PyArrowObjectType::PyCapsuleInterface) {
entry = entry.attr("__arrow_c_stream__")();
type = PyArrowObjectType::PyCapsule;
}

auto stream_factory = make_uniq<PythonTableArrowArrayStreamFactory>(entry.ptr(), client_properties);
auto stream_factory_produce = PythonTableArrowArrayStreamFactory::Produce;
auto stream_factory_get_schema = PythonTableArrowArrayStreamFactory::GetSchema;
@@ -66,8 +61,17 @@ static void CreateArrowScan(const string &name, py::object entry, TableFunctionR
make_uniq<ConstantExpression>(Value::POINTER(CastPointerToValue(stream_factory_get_schema))));

if (type == PyArrowObjectType::PyCapsule) {
// Disable projection+filter pushdown
// Disable projection+filter pushdown for bare capsules (single-use, no PyArrow wrapper)
table_function.function = make_uniq<FunctionExpression>("arrow_scan_dumb", std::move(children));
} else if (type == PyArrowObjectType::PyCapsuleInterface) {
// Try to load pyarrow.dataset for pushdown support
auto &cache = *DuckDBPyConnection::ImportCache();
if (!cache.pyarrow.dataset()) {
// No pyarrow.dataset: scan without pushdown, DuckDB handles projection/filter post-scan
table_function.function = make_uniq<FunctionExpression>("arrow_scan_dumb", std::move(children));
} else {
table_function.function = make_uniq<FunctionExpression>("arrow_scan", std::move(children));
}
} else {
table_function.function = make_uniq<FunctionExpression>("arrow_scan", std::move(children));
}
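Either branch of this tier check returns the same rows; only where the projection and filter are evaluated differs. Assuming an object obj that implements __arrow_c_stream__ (as sketched earlier), the user-visible contract is simply:

import duckdb
# pyarrow.dataset installed -> arrow_scan: the "a" projection and the a > 1 filter are pushed into a pyarrow Scanner
# pyarrow.dataset missing   -> arrow_scan_dumb: the full stream is read and DuckDB filters afterwards
duckdb.sql("select a from obj where a > 1").fetchall()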
@@ -141,6 +145,9 @@ unique_ptr<TableRef> PythonReplacementScan::TryReplacementObject(const py::objec
subquery->external_dependency = std::move(dependency);
return std::move(subquery);
} else if (PolarsDataFrame::IsDataFrame(entry)) {
// Polars DataFrames always go through one-time .to_arrow() materialization.
// Polars's __arrow_c_stream__() serializes from its internal layout on every call,
// which is expensive for repeated scans. The .to_arrow() path converts once.
auto arrow_dataset = entry.attr("to_arrow")();
CreateArrowScan(name, arrow_dataset, *table_function, children, client_properties, PyArrowObjectType::Table,
*context.db);
Expand All @@ -149,9 +156,8 @@ unique_ptr<TableRef> PythonReplacementScan::TryReplacementObject(const py::objec
auto arrow_dataset = materialized.attr("to_arrow")();
CreateArrowScan(name, arrow_dataset, *table_function, children, client_properties, PyArrowObjectType::Table,
*context.db);
} else if (DuckDBPyConnection::GetArrowType(entry) != PyArrowObjectType::Invalid &&
!(DuckDBPyConnection::GetArrowType(entry) == PyArrowObjectType::MessageReader && !relation)) {
arrow_type = DuckDBPyConnection::GetArrowType(entry);
} else if ((arrow_type = DuckDBPyConnection::GetArrowType(entry)) != PyArrowObjectType::Invalid &&
!(arrow_type == PyArrowObjectType::MessageReader && !relation)) {
CreateArrowScan(name, entry, *table_function, children, client_properties, arrow_type, *context.db);
} else if (DuckDBPyConnection::IsAcceptedNumpyObject(entry) != NumpyObjectType::INVALID) {
numpytype = DuckDBPyConnection::IsAcceptedNumpyObject(entry);
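For the Polars branches above, the user-visible behavior is that DataFrames and LazyFrames stay scannable by variable name; the .to_arrow() conversion (after collecting, in the LazyFrame case) happens once when the replacement scan is built. A short usage sketch with illustrative data:

import duckdb
import polars as pl

df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
# The replacement scan materializes df via df.to_arrow() once, then scans the resulting Arrow table
print(duckdb.sql("select a, b from df where a >= 2").fetchall())  # [(2, 'y'), (3, 'z')]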
9 changes: 6 additions & 3 deletions tests/fast/arrow/test_arrow_pycapsule.py
@@ -29,21 +29,24 @@ def __arrow_c_stream__(self, requested_schema=None) -> object:
obj = MyObject(df)

# Call the __arrow_c_stream__ from within DuckDB
# MyObject has no __arrow_c_schema__, so GetSchema() falls back to __arrow_c_stream__ (1 call),
# then Produce() calls __arrow_c_stream__ again (1 call) = 2 calls minimum per scan.
res = duckdb_cursor.sql("select * from obj")
assert res.fetchall() == [(1, 5), (2, 6), (3, 7), (4, 8)]
assert obj.count == 1
count_after_first = obj.count
assert count_after_first >= 2

# Call the __arrow_c_stream__ method and pass in the capsule instead
capsule = obj.__arrow_c_stream__()
res = duckdb_cursor.sql("select * from capsule")
assert res.fetchall() == [(1, 5), (2, 6), (3, 7), (4, 8)]
assert obj.count == 2
assert obj.count == count_after_first + 1

# Ensure __arrow_c_stream__ accepts a requested_schema argument as noop
capsule = obj.__arrow_c_stream__(requested_schema="foo") # noqa: F841
res = duckdb_cursor.sql("select * from capsule")
assert res.fetchall() == [(1, 5), (2, 6), (3, 7), (4, 8)]
assert obj.count == 3
assert obj.count == count_after_first + 2

def test_capsule_roundtrip(self, duckdb_cursor):
def create_capsule():