Skip to content

Commit 5959f54

Browse files
committed
statement fixes
1 parent cb30bd3 commit 5959f54

File tree

3 files changed

+91
-16
lines changed

3 files changed

+91
-16
lines changed

pystackql/base_stackql_magic.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import print_function
22
from IPython.core.magic import (Magics)
33
from string import Template
4+
import pandas as pd
45

56
class BaseStackqlMagic(Magics):
67
"""Base Jupyter magic extension enabling running StackQL queries.
@@ -37,4 +38,8 @@ def run_query(self, query):
3738
:return: Query results, returned as a Pandas DataFrame.
3839
:rtype: pandas.DataFrame
3940
"""
41+
# Check if the query starts with "registry pull" (case insensitive)
42+
if query.strip().lower().startswith("registry pull"):
43+
return self.stackql_instance.executeStmt(query)
44+
4045
return self.stackql_instance.execute(query)

pystackql/stackql.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,19 +124,24 @@ def _connect_to_server(self):
124124
print(f"Unexpected error while connecting to the server: {e}")
125125
return None
126126

127-
def _run_server_query(self, query):
127+
def _run_server_query(self, query, is_statement=False):
128128
"""Runs a query against the server using psycopg2.
129129
130130
:param query: SQL query to be executed on the server.
131131
:type query: str
132132
:return: List of result rows if the query fetches results; empty list if there are no results.
133-
:rtype: list
133+
:rtype: list of dict objects
134134
:raises: psycopg2.ProgrammingError for issues related to the SQL query,
135135
unless the error is "no results to fetch", in which case an empty list is returned.
136136
"""
137137
try:
138138
cur = self._conn.cursor(cursor_factory=RealDictCursor)
139139
cur.execute(query)
140+
if is_statement:
141+
# If the query is a statement, there are no results to fetch.
142+
result_msg = cur.statusmessage
143+
cur.close()
144+
return [{'message': result_msg}]
140145
rows = cur.fetchall()
141146
cur.close()
142147
return rows
@@ -146,7 +151,7 @@ def _run_server_query(self, query):
146151
else:
147152
raise
148153

149-
def _run_query(self, query, is_statement=False):
154+
def _run_query(self, query):
150155
"""Internal method to execute a StackQL query using a subprocess.
151156
152157
The method spawns a subprocess to run the StackQL binary with the specified query and parameters.
@@ -395,7 +400,7 @@ def executeStmt(self, query):
395400
against the server. Otherwise, it executes the query using a subprocess.
396401
397402
:param query: The StackQL query string to be executed.
398-
:type query: str
403+
:type query: str, list of dict objects, or Pandas DataFrame
399404
400405
:return: The output result of the query in string format. If in `server_mode`, it
401406
returns a JSON string representation of the result.
@@ -406,14 +411,26 @@ def executeStmt(self, query):
406411
>>> stackql = StackQL()
407412
>>> stackql_query = "REGISTRY PULL okta"
408413
>>> result = stackql.executeStmt(stackql_query)
409-
>>> print(result)
414+
>>> result
410415
"""
411416
if self.server_mode:
412417
# Use server mode
413-
result = self._run_server_query(query)
414-
return json.dumps(result)
418+
result = self._run_server_query(query, True)
419+
if self.output == 'pandas':
420+
return pd.DataFrame(result)
421+
elif self.output == 'csv':
422+
# return the string representation of the result
423+
return result[0]['message']
424+
else:
425+
return result
415426
else:
416-
return self._run_query(query, is_statement=True)
427+
result_msg = self._run_query(query)
428+
if self.output == 'pandas':
429+
return pd.DataFrame({'message': [result_msg]})
430+
elif self.output == 'csv':
431+
return result_msg
432+
else:
433+
return [{'message': result_msg}]
417434

418435
def execute(self, query):
419436
"""Executes a query using the StackQL instance and returns the output

tests/pystackql_tests.py

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,37 @@ def test_09_csv_output_with_header(self):
155155

156156
@pystackql_test_setup()
157157
def test_10_executeStmt(self):
158+
okta_result_dict = self.stackql.executeStmt(registry_pull_okta_query)
159+
okta_result = okta_result_dict[0]["message"]
160+
expected_pattern = registry_pull_resp_pattern("okta")
161+
self.assertTrue(re.search(expected_pattern, okta_result), f"Expected pattern not found in result: {okta_result}")
162+
github_result_dict = self.stackql.executeStmt(registry_pull_github_query)
163+
github_result = github_result_dict[0]["message"]
164+
expected_pattern = registry_pull_resp_pattern("github")
165+
self.assertTrue(re.search(expected_pattern, github_result), f"Expected pattern not found in result: {github_result}")
166+
print_test_result(f"""Test executeStmt method\nRESULTS:\n{okta_result_dict}\n{github_result_dict}""", True)
167+
168+
@pystackql_test_setup(output="csv")
169+
def test_10a_executeStmt_with_csv_output(self):
158170
okta_result = self.stackql.executeStmt(registry_pull_okta_query)
159171
expected_pattern = registry_pull_resp_pattern("okta")
160172
self.assertTrue(re.search(expected_pattern, okta_result), f"Expected pattern not found in result: {okta_result}")
161173
github_result = self.stackql.executeStmt(registry_pull_github_query)
162174
expected_pattern = registry_pull_resp_pattern("github")
163175
self.assertTrue(re.search(expected_pattern, github_result), f"Expected pattern not found in result: {github_result}")
164-
print_test_result(f"""Test executeStmt method\nRESULTS:\n{okta_result}{github_result}""", True)
176+
print_test_result(f"""Test executeStmt method with csv output\nRESULTS:\n{okta_result}\n{github_result}""", True)
177+
178+
@pystackql_test_setup(output="pandas")
179+
def test_10b_executeStmt_with_pandas_output(self):
180+
okta_result_df = self.stackql.executeStmt(registry_pull_okta_query)
181+
okta_result = okta_result_df['message'].iloc[0]
182+
expected_pattern = registry_pull_resp_pattern("okta")
183+
self.assertTrue(re.search(expected_pattern, okta_result), f"Expected pattern not found in result: {okta_result}")
184+
github_result_df = self.stackql.executeStmt(registry_pull_github_query)
185+
github_result = github_result_df['message'].iloc[0]
186+
expected_pattern = registry_pull_resp_pattern("github")
187+
self.assertTrue(re.search(expected_pattern, github_result), f"Expected pattern not found in result: {github_result}")
188+
print_test_result(f"""Test executeStmt method with pandas output\nRESULTS:\n{okta_result_df}\n{github_result_df}""", True)
165189

166190
@pystackql_test_setup()
167191
def test_11_execute_with_defaults(self):
@@ -232,13 +256,16 @@ def test_19_server_mode_connectivity(self):
232256
@pystackql_test_setup(server_mode=True)
233257
def test_20_executeStmt_server_mode(self):
234258
result = self.stackql.executeStmt(registry_pull_google_query)
235-
is_valid_json_string_of_empty_list = False
236-
try:
237-
parsed_result = json.loads(result)
238-
is_valid_json_string_of_empty_list = isinstance(parsed_result, list) and len(parsed_result) == 0
239-
except json.JSONDecodeError:
240-
pass
241-
print_test_result("Test executeStmt in server mode", is_valid_json_string_of_empty_list, True)
259+
# Checking if the result is a list containing a single dictionary with a key 'message' and value 'OK'
260+
is_valid_response = isinstance(result, list) and len(result) == 1 and result[0].get('message') == 'OK'
261+
print_test_result(f"Test executeStmt in server mode\n{result}", is_valid_response, True)
262+
263+
@pystackql_test_setup(server_mode=True, output='pandas')
264+
def test_20a_executeStmt_server_mode_with_pandas_output(self):
265+
result_df = self.stackql.executeStmt(registry_pull_google_query)
266+
# Verifying if the result is a dataframe with a column 'message' containing the value 'OK' in its first row
267+
is_valid_response = isinstance(result_df, pd.DataFrame) and 'message' in result_df.columns and result_df['message'].iloc[0] == 'OK'
268+
print_test_result(f"Test executeStmt in server mode with pandas output\n{result_df}", is_valid_response, True)
242269

243270
@pystackql_test_setup(server_mode=True)
244271
def test_21_execute_server_mode_default_output(self):
@@ -288,6 +315,7 @@ def setUp(self):
288315
self.stackql_magic = self.MAGIC_CLASS(shell=self.shell)
289316
self.query = "SELECT 1 as fred"
290317
self.expected_result = pd.DataFrame({"fred": [1]})
318+
self.statement = "REGISTRY PULL github"
291319

292320
def print_test_result(self, test_name, *checks):
293321
all_passed = all(checks)
@@ -320,6 +348,31 @@ def test_cell_magic_query_no_output(self):
320348
checks = self.run_magic_test(line="--no-display", cell=self.query, expect_none=True)
321349
self.print_test_result("Cell magic test (with --no-display)", *checks)
322350

351+
def run_magic_statement_test(self, line, cell, expect_none=False):
352+
# Execute the magic with our statement.
353+
result = self.stackql_magic.stackql(line=line, cell=cell)
354+
# Validate the outcome.
355+
checks = []
356+
if expect_none:
357+
checks.append(result is None)
358+
else:
359+
# Check that the output contains expected content
360+
checks.append("OK" in result["message"].iloc[0])
361+
checks.append('stackql_df' in self.shell.user_ns)
362+
checks.append("OK" in self.shell.user_ns['stackql_df']["message"].iloc[0])
363+
return checks
364+
365+
def test_line_magic_statement(self):
366+
checks = self.run_magic_statement_test(line=self.statement, cell=None)
367+
self.print_test_result("Line magic statement", *checks)
368+
369+
def test_cell_magic_statement(self):
370+
checks = self.run_magic_statement_test(line="", cell=self.statement)
371+
self.print_test_result("Cell magic statement", *checks)
372+
373+
def test_cell_magic_statement_no_output(self):
374+
checks = self.run_magic_statement_test(line="--no-display", cell=self.statement, expect_none=True)
375+
self.print_test_result("Cell magic statement (with --no-display)", *checks)
323376

324377
class StackQLMagicTests(BaseStackQLMagicTests, unittest.TestCase):
325378

0 commit comments

Comments
 (0)