better support for params within athena

This commit is contained in:
Ari Brown 2023-08-03 11:42:56 -04:00
parent ea31c8e8c0
commit 0e5542cecf
3 changed files with 32 additions and 10 deletions

View file

@ -16,12 +16,12 @@ class Athena:
self.client = handler.session.client("athena")
self.output = output
def query(self, sql, params=None):
    """Run `sql` as a `Query` and return the query object.

    Args:
        sql: SQL text to submit.
        params: Optional list of values for the statement's `?`
            placeholders. Omit it (or pass `None`) for no parameters.

    Returns:
        The `Query` instance, after its `run()` method has been called.
    """
    # A `None` sentinel replaces the mutable `[]` default so that no list
    # object is shared between separate calls.
    q = Query(self, sql, [] if params is None else params)
    q.run()
    return q
def execute(self, sql, params=None):
    """Run `sql` as an `Execute` and return the execution object.

    Args:
        sql: SQL text to submit.
        params: Optional list of values for the statement's `?`
            placeholders. Omit it (or pass `None`) for no parameters.

    Returns:
        The `Execute` instance, after its `run()` method has been called.
    """
    # A `None` sentinel replaces the mutable `[]` default so that no list
    # object is shared between separate calls.
    e = Execute(self, sql, [] if params is None else params)
    e.run()
    return e
@ -44,8 +44,12 @@ class Execute:
return self.sql
def run(self):
    """Submit the query through the client and remember its execution id.

    `ExecutionParameters` is only included in the request when
    `self.params` is non-empty; otherwise the request carries just the
    query string.
    NOTE(review): presumably the service rejects an empty parameter
    list -- confirm against the boto3 documentation.

    Returns:
        The response returned by `start_query_execution`. Its
        `QueryExecutionId` is also stored on `self.query_id`.
    """
    # Build the request once so the query is submitted exactly one time.
    request = {"QueryString": self.query()}
    if self.params:
        request["ExecutionParameters"] = self.params
    resp = self.client.start_query_execution(**request)
    self.query_id = resp['QueryExecutionId']
    return resp
@ -87,7 +91,6 @@ class Query(Execute):
# Because we're using `UNLOAD`, we get a manifest of the files
# that make up our data.
manif = self.info_cache['Statistics']['DataManifestLocation']
print(manif)
tmp = self.handler.s3.download(manif)
with open(tmp, "r") as f:
txt = f.read()
@ -97,7 +100,10 @@ class Query(Execute):
return files
else:
return status # canceled or error
print("Error")
print(self.info_cache['Status']['AthenaError']['ErrorMessage'])
raise
#return status # canceled or error
def results(self):
#local = [self.handler.s3.download(f) for f in self.manifest_files()]

View file

@ -24,4 +24,6 @@ def parallel_map(func=None, data=None, cores=8):
res = Parallel(n_jobs=cores)(delayed(wrapper_func)(group) for group in groups)
# Flatten the nested lists
return [val for r in res for val in r]

22
test.py
View file

@ -5,6 +5,7 @@ pp = pprint.PrettyPrinter(indent=4)
# NOTE(review): "hay" is presumably a profile/credentials name understood by
# Minerva -- confirm against the Minerva constructor.
m = minerva.Minerva("hay")
# NOTE(review): the S3 URI is presumably where Athena writes query output --
# confirm against `Minerva.athena`.
athena = m.athena("s3://haystac-pmo-athena/")
#query = athena.query(
#"""SELECT *
#FROM trajectories.kitware
@ -13,13 +14,26 @@ athena = m.athena("s3://haystac-pmo-athena/")
# ST_Point(longitude, latitude)
#)
#""")
# Parameterized query: the `?` placeholder is filled from the params list
# passed as the second argument.
# NOTE(review): boto3 documents ExecutionParameters as a list of strings --
# confirm the integer 4 is converted before the request is sent.
query = athena.query("select count(*) as count from trajectories.kitware where agent = ?", [4])
data = query.results()
pp.pprint(data.head(10))
print(query.runtime)
# Everything *needs* to have a column in order for parquet to work, so scalar
# values have to be assigned something, so here we use `as count` to create
# a temporary column called `count`
#print(athena.query("select count(*) as count from trajectories.kitware").results().head(1))
# Query without parameters: for agent 4, count rows per longitude rounded to
# three decimals, ordered by descending count.
query = athena.query(
"""
select round(longitude, 3) as lon, count(*) as count
from trajectories.kitware
where agent = 4
group by round(longitude, 3)
order by count(*) desc
"""
)
data = query.results()
# Show the first ten rows of the result and how long the query took.
pp.pprint(data.head(10))
print(query.runtime)
#import IPython
#IPython.embed()