Skip to content

Commit 58388ce

Browse files
committed
fix: fix streaming dataframe output
1 parent d95e41a commit 58388ce

File tree

10 files changed

+125
-28
lines changed

10 files changed

+125
-28
lines changed

programs/local/ChdbClient.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
#if USE_PYTHON
1616
#include <PythonTableCache.h>
17+
#include <PandasDataFrameBuilder.h>
18+
#include <pybind11/pybind11.h>
19+
namespace py = pybind11;
1720
#endif
1821

1922
namespace DB
@@ -318,14 +321,21 @@ CHDB::QueryResultPtr ChdbClient::executeStreamingIterate(void * streaming_result
318321
#if USE_PYTHON
319322
if (Poco::toLower(default_output_format) == "dataframe")
320323
{
321-
res = std::make_unique<CHDB::ChunkQueryResult>(
324+
auto rows_read = processed_rows - old_processed_rows;
325+
auto chunk_result = std::make_unique<CHDB::ChunkQueryResult>(
322326
std::move(collected_chunks),
323327
std::move(collected_chunks_header),
324328
elapsed_time - old_elapsed_time,
325-
processed_rows - old_processed_rows,
329+
rows_read,
326330
processed_bytes - old_processed_bytes,
327331
storage_rows_read - old_storage_rows_read,
328332
storage_bytes_read - old_storage_bytes_read);
333+
334+
py::gil_scoped_acquire acquire;
335+
CHDB::PandasDataFrameBuilder builder(*chunk_result);
336+
py::handle df = builder.getDataFrame().release();
337+
338+
res = std::make_unique<CHDB::DataFrameQueryResult>(df, rows_read);
329339
}
330340
else
331341
#endif
@@ -350,7 +360,7 @@ CHDB::QueryResultPtr ChdbClient::executeStreamingIterate(void * streaming_result
350360
}
351361
}
352362

353-
// Check if query should end based on result type
363+
/// Check if query should end based on result type
354364
bool is_end = !res->getError().empty() || is_canceled || res->isEmpty();
355365
if (is_end)
356366
{

programs/local/EmbeddedServer.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,6 @@
7474
# include <azure/storage/common/internal/xml_wrapper.hpp>
7575
#endif
7676

77-
bool chdb_embedded_server_initialized = false;
78-
7977
namespace fs = std::filesystem;
8078

8179
namespace CurrentMetrics
@@ -513,8 +511,6 @@ try
513511
global_register_once_flag,
514512
[]()
515513
{
516-
chdb_embedded_server_initialized = true;
517-
518514
registerInterpreters();
519515
/// Don't initialize DateLUT
520516
registerFunctions();

programs/local/LocalChdb.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include "PandasDataFrameBuilder.h"
44
#include "ChunkCollectorOutputFormat.h"
55
#include "PythonImporter.h"
6-
#include "PythonTableCache.h"
76
#include "StoragePython.h"
87

98
#include <pybind11/detail/non_limited_api.h>
@@ -274,13 +273,14 @@ query_result * connection_wrapper::query(const std::string & query_str, const st
274273

275274
auto * result = chdb_query_n(*conn, query_str.data(), query_str.size(), format.data(), format.size());
276275

277-
auto error_msg = CHDB::chdb_result_error_string(result);
276+
const auto & error_msg = CHDB::chdb_result_error_string(result);
278277
if (!error_msg.empty())
279278
{
280279
std::string msg_copy(error_msg);
281280
chdb_destroy_query_result(result);
282281
throw std::runtime_error(msg_copy);
283282
}
283+
284284
return new query_result(result, false);
285285
}
286286

@@ -298,7 +298,7 @@ py::object connection_wrapper::query_df(const std::string & query_str)
298298

299299
result = chdb_query_n(*conn, query_str.data(), query_str.size(), format.data(), format.size());
300300

301-
auto error_msg = CHDB::chdb_result_error_string(result);
301+
const auto & error_msg = CHDB::chdb_result_error_string(result);
302302
if (!error_msg.empty())
303303
{
304304
std::string msg_copy(error_msg);
@@ -322,7 +322,7 @@ streaming_query_result * connection_wrapper::send_query(const std::string & quer
322322
CHDB::cachePythonTablesFromQuery(reinterpret_cast<chdb_conn *>(*conn), query_str);
323323
py::gil_scoped_release release;
324324
auto * result = chdb_stream_query_n(*conn, query_str.data(), query_str.size(), format.data(), format.size());
325-
auto error_msg = CHDB::chdb_result_error_string(result);
325+
const auto & error_msg = CHDB::chdb_result_error_string(result);
326326
if (!error_msg.empty())
327327
{
328328
std::string msg_copy(error_msg);
@@ -342,7 +342,7 @@ query_result * connection_wrapper::streaming_fetch_result(streaming_query_result
342342

343343
auto * result = chdb_stream_fetch_result(*conn, streaming_result->get_result());
344344

345-
const auto error_msg = CHDB::chdb_result_error_string(result);
345+
const auto & error_msg = CHDB::chdb_result_error_string(result);
346346
if (!error_msg.empty())
347347
{
348348
std::string msg_copy(error_msg);
@@ -359,30 +359,29 @@ py::object connection_wrapper::streaming_fetch_df(streaming_query_result * strea
359359
return py::none();
360360

361361
chdb_result * result = nullptr;
362-
CHDB::ChunkQueryResult * chunk_result = nullptr;
362+
CHDB::DataFrameQueryResult * chunk_result = nullptr;
363363

364364
{
365365
py::gil_scoped_release release;
366366

367-
result = chdb_stream_fetch_result(*conn, streaming_result->get_result());
367+
result = chdb_stream_fetch_result(*conn, streaming_result->get_result());
368368

369-
auto error_msg = CHDB::chdb_result_error_string(result);
369+
const auto & error_msg = CHDB::chdb_result_error_string(result);
370370
if (!error_msg.empty())
371371
{
372372
std::string msg_copy(error_msg);
373373
chdb_destroy_query_result(result);
374374
throw std::runtime_error(msg_copy);
375375
}
376376

377-
if (!(chunk_result = dynamic_cast<CHDB::ChunkQueryResult *>(reinterpret_cast<CHDB::QueryResult*>(result))))
378-
throw std::runtime_error("Expected ChunkQueryResult for dataframe format");
377+
if (!(chunk_result = dynamic_cast<CHDB::DataFrameQueryResult *>(reinterpret_cast<CHDB::QueryResult*>(result))))
378+
throw std::runtime_error("Expected DataFrameQueryResult for dataframe format");
379379
}
380380

381-
CHDB::PandasDataFrameBuilder builder(*chunk_result);
382-
auto df = builder.getDataFrame();
381+
py::handle df_handle = chunk_result->dataframe;
383382
chdb_destroy_query_result(result);
384383

385-
return df;
384+
return py::reinterpret_steal<py::object>(df_handle);
386385
}
387386

388387
void connection_wrapper::streaming_cancel_query(streaming_query_result * streaming_result)

programs/local/LocalServer.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,6 @@
7373
# include <azure/storage/common/internal/xml_wrapper.hpp>
7474
#endif
7575

76-
extern bool chdb_embedded_server_initialized;
77-
7876
namespace fs = std::filesystem;
7977

8078
namespace CurrentMetrics
@@ -650,8 +648,6 @@ try
650648
global_register_once_flag,
651649
[]()
652650
{
653-
chdb_embedded_server_initialized = true;
654-
655651
registerInterpreters();
656652
/// Don't initialize DateLUT
657653
registerFunctions();

programs/local/QueryResult.h

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
#if USE_PYTHON
1111
#include <Processors/Chunk.h>
12+
#include <pybind11/pybind11.h>
13+
namespace py = pybind11;
1214
namespace DB
1315
{
1416
class Block;
@@ -23,7 +25,8 @@ enum class QueryResultType : uint8_t
2325
RESULT_TYPE_MATERIALIZED = 0,
2426
RESULT_TYPE_STREAMING = 1,
2527
RESULT_TYPE_CHUNK = 2,
26-
RESULT_TYPE_NONE = 3
28+
RESULT_TYPE_DATAFRAME = 3,
29+
RESULT_TYPE_NONE = 4
2730
};
2831

2932
class QueryResult
@@ -144,13 +147,34 @@ class ChunkQueryResult : public QueryResult
144147
uint64_t storage_rows_read;
145148
uint64_t storage_bytes_read;
146149
};
150+
151+
class DataFrameQueryResult : public QueryResult
152+
{
153+
public:
154+
explicit DataFrameQueryResult(
155+
py::handle dataframe_,
156+
uint64_t rows_read)
157+
: QueryResult(QueryResultType::RESULT_TYPE_DATAFRAME),
158+
dataframe(dataframe_),
159+
is_empty(rows_read == 0)
160+
{}
161+
162+
bool isEmpty() const override
163+
{
164+
return is_empty;
165+
}
166+
167+
py::handle dataframe;
168+
bool is_empty;
169+
};
147170
#endif
148171

149172
using QueryResultPtr = std::unique_ptr<QueryResult>;
150173
using MaterializedQueryResultPtr = std::unique_ptr<MaterializedQueryResult>;
151174
using StreamQueryResultPtr = std::unique_ptr<StreamQueryResult>;
152175
#if USE_PYTHON
153176
using ChunkQueryResultPtr = std::unique_ptr<ChunkQueryResult>;
177+
using DataFrameQueryResultPtr = std::unique_ptr<DataFrameQueryResult>;
154178
#endif
155179

156180
} // namespace CHDB

programs/local/chdb-arrow.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,17 @@ chdb_state chdb_arrow_scan(
134134
chdb_connection conn, const char * table_name,
135135
chdb_arrow_stream arrow_stream)
136136
{
137+
CHDB::ChdbMemoryTrackingGuard guard;
138+
137139
return chdb_inner_arrow_scan(conn, table_name, arrow_stream, false);
138140
}
139141

140142
chdb_state chdb_arrow_array_scan(
141143
chdb_connection conn, const char * table_name,
142144
chdb_arrow_schema arrow_schema, chdb_arrow_array arrow_array)
143145
{
146+
CHDB::ChdbMemoryTrackingGuard guard;
147+
144148
auto * private_data = new CHDB::PrivateData();
145149
private_data->schema = reinterpret_cast<ArrowSchema *>(arrow_schema);
146150
private_data->array = reinterpret_cast<ArrowArray *>(arrow_array);
@@ -158,6 +162,8 @@ chdb_state chdb_arrow_array_scan(
158162

159163
chdb_state chdb_arrow_unregister_table(chdb_connection conn, const char * table_name)
160164
{
165+
CHDB::ChdbMemoryTrackingGuard guard;
166+
161167
if (!table_name)
162168
return CHDBError;
163169

programs/local/chdb-internal.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,29 @@ inline bool checkConnectionValidity(chdb_conn * connection)
1818
return connection && connection->connected;
1919
}
2020

21+
extern thread_local bool chdb_memory_tracking;
22+
2123
namespace CHDB
2224
{
2325

26+
class ChdbMemoryTrackingGuard
27+
{
28+
public:
29+
ChdbMemoryTrackingGuard()
30+
: previous_value(chdb_memory_tracking)
31+
{
32+
chdb_memory_tracking = true;
33+
}
34+
35+
~ChdbMemoryTrackingGuard()
36+
{
37+
chdb_memory_tracking = previous_value;
38+
}
39+
40+
private:
41+
bool previous_value;
42+
};
43+
2444
std::unique_ptr<MaterializedQueryResult> pyEntryClickHouseLocal(int argc, char ** argv);
2545

2646
const std::string & chdb_result_error_string(chdb_result * result);

0 commit comments

Comments
 (0)