From 0e5542cecf4a3fad2343cf2677583cde7bd2d17b Mon Sep 17 00:00:00 2001 From: Ari Brown Date: Thu, 3 Aug 2023 11:42:56 -0400 Subject: [PATCH] better support for params within athena --- minerva/athena.py | 18 ++++++++++++------ minerva/parallel.py | 2 ++ test.py | 22 ++++++++++++++++++---- 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/minerva/athena.py b/minerva/athena.py index 3ed9a0d..6c12e48 100644 --- a/minerva/athena.py +++ b/minerva/athena.py @@ -16,12 +16,12 @@ class Athena: self.client = handler.session.client("athena") self.output = output - def query(self, sql, params): + def query(self, sql, params=[]): q = Query(self, sql, params) q.run() return q - def execute(self, sql): + def execute(self, sql, params=[]): e = Execute(self, sql, params) e.run() return e @@ -44,8 +44,12 @@ class Execute: return self.sql def run(self): - resp = self.client.start_query_execution(QueryString=self.query(), - ExecutionParameters=self.params) + if self.params: + resp = self.client.start_query_execution(QueryString=self.query(), + ExecutionParameters=self.params) + else: + resp = self.client.start_query_execution(QueryString=self.query()) + self.query_id = resp['QueryExecutionId'] return resp @@ -87,7 +91,6 @@ class Query(Execute): # Because we're using `UNLOAD`, we get a manifest of the files # that make up our data. manif = self.info_cache['Statistics']['DataManifestLocation'] - print(manif) tmp = self.handler.s3.download(manif) with open(tmp, "r") as f: txt = f.read() @@ -97,7 +100,10 @@ class Query(Execute): return files else: - return status # canceled or error + print("Error") + print(self.info_cache['Status']['AthenaError']['ErrorMessage']) + raise + #return status # canceled or error def results(self): #local = [self.handler.s3.download(f) for f in self.manifest_files()] diff --git a/minerva/parallel.py b/minerva/parallel.py index df11ab8..8e97ead 100644 --- a/minerva/parallel.py +++ b/minerva/parallel.py @@ -24,4 +24,6 @@ def parallel_map(func=None, data=None, cores=8): res = Parallel(n_jobs=cores)(delayed(wrapper_func)(group) for group in groups) + # Flatten the nested lists return [val for r in res for val in r] + diff --git a/test.py b/test.py index a87c713..8c64435 100644 --- a/test.py +++ b/test.py @@ -5,6 +5,7 @@ pp = pprint.PrettyPrinter(indent=4) m = minerva.Minerva("hay") athena = m.athena("s3://haystac-pmo-athena/") + #query = athena.query( #"""SELECT * #FROM trajectories.kitware @@ -13,13 +14,26 @@ athena = m.athena("s3://haystac-pmo-athena/") # ST_Point(longitude, latitude) #) #""") -query = athena.query("select count(*) as count from trajectories.kitware where agent = ?", [4]) -data = query.results() -pp.pprint(data.head(10)) -print(query.runtime) # Everything *needs* to have a column in order for parquet to work, so scalar # values have to be assigned something, so here we use `as count` to create # a temporary column called `count` #print(athena.query("select count(*) as count from trajectories.kitware").results().head(1)) +query = athena.query( + """ + select round(longitude, 3) as lon, count(*) as count + from trajectories.kitware + where agent = 4 + group by round(longitude, 3) + order by count(*) desc + """ +) +data = query.results() +pp.pprint(data.head(10)) +print(query.runtime) + +#import IPython +#IPython.embed() + +