parameterized queries for athena, but not for redshift.

Ari Brown 2023-08-02 16:51:02 -04:00
parent c4e0b71a98
commit 52f44a79fe
6 changed files with 82 additions and 53 deletions


@@ -1,9 +1,13 @@
 from .parallel import parallel_map
 from .athena import Athena
 from .redshift import Redshift
+from .s3 import S3
+from .minerva import Minerva
 
 __all__ = [
     "Athena",
     "Redshift",
-    "parallel_map"
+    "parallel_map",
+    "Minerva",
+    "S3"
 ]
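
Note: `from .minerva import Minerva` refers to a handler class that is not among the six files in this diff. A minimal sketch of what it presumably looks like, inferred from how the tests and the Athena/Redshift/S3 classes use it (every name and signature here is an inference, not the committed code):

    import boto3
    from .athena import Athena
    from .redshift import Redshift
    from .s3 import S3

    class Minerva:
        def __init__(self, profile):
            # One shared boto3 session that every service client hangs off of
            self.session = boto3.session.Session(profile_name=profile)
            self.s3 = S3(self)

        def athena(self, output):
            return Athena(self, output)

        def redshift(self, output, db=None, cluster=None):
            return Redshift(self, output, db=db, cluster=cluster)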


@@ -11,45 +11,41 @@ from minerva import parallel_map
 pp = pprint.PrettyPrinter(indent=4)
 
 class Athena:
-    def __init__(self, profile, output):
-        self.session = boto3.session.Session(profile_name=profile)
-        self.athena = self.session.client("athena")
+    def __init__(self, handler, output):
+        self.handler = handler
+        self.client = handler.session.client("athena")
         self.output = output
 
-    def query(self, sql):
-        q = Query(self, sql)
+    def query(self, sql, params=[]):
+        q = Query(self, sql, params)
         q.run()
         return q
 
-    def execute(self, sql):
-        e = Execute(self, sql)
+    def execute(self, sql, params=[]):
+        e = Execute(self, sql, params)
         e.run()
         return e
 
-    def download(self, s3):
-        bucket = s3.split("/")[2]
-        file = os.path.join(*s3.split("/")[3:])
-        tmp = f"/tmp/{random.random()}.bin"
-        self.session.client('s3').download_file(bucket, file, tmp)
-        return tmp
-
 class Execute:
     """
     Execute will not return results, but will execute the SQL and return the final state.
     Execute is meant to be used for DML statements such as CREATE DATABASE/TABLE
     """
-    def __init__(self, handler, sql):
-        self.handler = handler
-        self.athena = handler.athena
+    def __init__(self, athena, sql, params=[]):
+        self.athena = athena
+        self.handler = athena.handler
+        self.client = athena.client
         self.sql = sql
+        self.params = [str(p) for p in params]
         self.info_cache = None
 
     def query(self):
         return self.sql
 
     def run(self):
-        resp = self.athena.start_query_execution(QueryString=self.query())
+        resp = self.client.start_query_execution(QueryString=self.query(),
+                                                 ExecutionParameters=self.params)
         self.query_id = resp['QueryExecutionId']
         return resp
@@ -57,7 +53,7 @@ class Execute:
         return self.info()['Status']['State']
 
     def info(self):
-        res = self.athena.get_query_execution(QueryExecutionId=self.query_id)
+        res = self.client.get_query_execution(QueryExecutionId=self.query_id)
         self.info_cache = res['QueryExecution']
         return self.info_cache
@@ -74,7 +70,7 @@ class Query(Execute):
     DATA_STYLE = 'parquet'
 
     def query(self):
-        out = os.path.join(self.handler.output,
+        out = os.path.join(self.athena.output,
                            str(random.random()))
         query = f"unload ({self.sql}) to {repr(out)} " + \
                 f"with (format = '{self.DATA_STYLE}')"
@@ -91,7 +87,8 @@ class Query(Execute):
         # Because we're using `UNLOAD`, we get a manifest of the files
         # that make up our data.
         manif = self.info_cache['Statistics']['DataManifestLocation']
-        tmp = self.handler.download(manif)
+        print(manif)
+        tmp = self.handler.s3.download(manif)
 
         with open(tmp, "r") as f:
             txt = f.read()
@@ -102,10 +99,9 @@ class Query(Execute):
         else:
             return status # canceled or error
 
-    # TODO parallelize this
     def results(self):
-        #local = [self.handler.download(f) for f in self.manifest_files()]
-        local = parallel_map(self.handler.download, self.manifest_files())
+        #local = [self.handler.s3.download(f) for f in self.manifest_files()]
+        local = parallel_map(self.handler.s3.download, self.manifest_files())
         self.ds = pa.dataset.dataset(local)
         return self.ds
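
For context on the `ExecutionParameters` change above: Athena binds a list of strings, in order, to the `?` placeholders in the SQL, which is why `Execute.__init__` coerces every value with `str(p)`. A standalone sketch of the underlying boto3 call (profile, database, and table names are placeholders; this assumes the workgroup already defines a result output location, as the committed code does):

    import boto3

    session = boto3.session.Session(profile_name="my-profile")
    client = session.client("athena")

    # Each string in ExecutionParameters binds, in order, to one `?`
    resp = client.start_query_execution(
        QueryString="select count(*) from my_db.my_table where agent = ?",
        ExecutionParameters=["4"],
    )
    print(resp["QueryExecutionId"])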


@@ -12,9 +12,9 @@ from minerva import parallel_map
 pp = pprint.PrettyPrinter(indent=4)
 
 class Redshift:
-    def __init__(self, profile, output, db=None, cluster=None):
-        self.session = boto3.session.Session(profile_name=profile)
-        self.redshift = self.session.client("redshift-data")
+    def __init__(self, handler, output, db=None, cluster=None):
+        self.handler = handler
+        self.client = handler.session.client("redshift-data")
         self.output = output
         self.database = db
         self.cluster = cluster
@@ -29,22 +29,16 @@ class Redshift:
         e.run()
         return e
 
-    def download(self, s3):
-        bucket = s3.split("/")[2]
-        file = os.path.join(*s3.split("/")[3:])
-        tmp = f"/tmp/{random.random()}.bin"
-        self.session.client('s3').download_file(bucket, file, tmp)
-        return tmp
-
 class Execute:
     """
     Execute will not return results, but will execute the SQL and return the final state.
     Execute is meant to be used for DML statements such as CREATE DATABASE/TABLE
     """
-    def __init__(self, handler, sql):
-        self.handler = handler
-        self.redshift = handler.redshift
+    def __init__(self, redshift, sql):
+        self.redshift = redshift
+        self.handler = redshift.handler
+        self.client = redshift.client
         self.sql = sql
         self.info_cache = None
@@ -52,9 +46,9 @@ class Execute:
         return self.sql
 
     def run(self):
-        resp = self.redshift.execute_statement(Sql=self.query(),
-                                               Database=self.handler.database,
-                                               ClusterIdentifier=self.handler.cluster)
+        resp = self.client.execute_statement(Sql=self.query(),
+                                             Database=self.redshift.database,
+                                             ClusterIdentifier=self.redshift.cluster)
         self.query_id = resp['Id']
         return resp
@@ -62,7 +56,7 @@ class Execute:
         return self.info()['Status']
 
     def info(self):
-        res = self.redshift.describe_statement(Id=self.query_id)
+        res = self.client.describe_statement(Id=self.query_id)
         self.info_cache = res
         return self.info_cache
@@ -79,7 +73,7 @@ class Query(Execute):
     DATA_STYLE = 'parquet'
 
     def query(self):
-        self.out = os.path.join(self.handler.output,
+        self.out = os.path.join(self.redshift.output,
                                 str(random.random()))
         query = f"unload ({repr(self.sql)}) to {repr(self.out)} " + \
                 f"iam_role default " + \
@@ -97,7 +91,7 @@ class Query(Execute):
         # Because we're using `UNLOAD`, we get a manifest of the files
         # that make up our data.
         manif = self.out + "manifest"
-        tmp = self.handler.download(manif)
+        tmp = self.handler.s3.download(manif)
 
         with open(tmp, "r") as f:
             js = json.load(f)
@@ -110,8 +104,8 @@ class Query(Execute):
     # TODO parallelize this
     def results(self):
-        #local = [self.handler.download(f) for f in self.manifest_files()]
-        local = parallel_map(self.handler.download, self.manifest_files())
+        #local = [self.handler.s3.download(f) for f in self.manifest_files()]
+        local = parallel_map(self.handler.s3.download, self.manifest_files())
         self.ds = pa.dataset.dataset(local)
         return self.ds


@@ -0,0 +1,33 @@
+import boto3
+import os
+import random
+
+class S3:
+    def __init__(self, handler):
+        self.handler = handler
+        self.s3 = handler.session.client("s3")
+
+    def parse(self, uri):
+        bucket = uri.split("/")[2]
+        file = os.path.join(*uri.split("/")[3:])
+        return bucket, file
+
+    def download(self, uri, loc=None):
+        loc = loc or f"/tmp/{random.random()}.bin"
+        bucket, key = self.parse(uri)
+        self.s3.download_file(bucket, key, loc)
+        return loc
+
+    def read(self, uri):
+        bucket, key = self.parse(uri)
+        file = self.s3.get_object(Bucket=bucket, Key=key)
+        return file['Body'].read().decode('utf-8')
+
+    def upload(self, local, remote):
+        # If `remote` is a directory, pick a name for the file
+        if remote[-1] == "/":
+            remote = os.path.join(remote, os.path.basename(local))
+        bucket, key = self.parse(remote)
+        return self.s3.upload_file(local, bucket, key)
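
A quick usage sketch for the new helper, assuming the `Minerva` handler sketched earlier (the profile name and URIs are placeholders):

    m = minerva.Minerva("my-profile")
    local = m.s3.download("s3://bucket/path/data.parquet")  # -> /tmp/<random>.bin
    text = m.s3.read("s3://bucket/path/notes.txt")          # whole object as utf-8
    m.s3.upload("report.csv", "s3://bucket/reports/")       # trailing "/" keeps the basename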


@@ -1,9 +1,10 @@
-import minerva as m
+import minerva
 import pprint
 
 pp = pprint.PrettyPrinter(indent=4)
 
-athena = m.Athena("hay", "s3://haystac-pmo-athena/")
+m = minerva.Minerva("hay")
+athena = m.athena("s3://haystac-pmo-athena/")
 
 #query = athena.query(
 #"""SELECT *
 #FROM trajectories.kitware
@@ -12,7 +13,7 @@ athena = m.Athena("hay", "s3://haystac-pmo-athena/")
 #    ST_Point(longitude, latitude)
 #)
 #""")
-query = athena.query("select count(*) as count from trajectories.kitware")
+query = athena.query("select count(*) as count from trajectories.kitware where agent = ?", [4])
 
 data = query.results()
 pp.pprint(data.head(10))
 print(query.runtime)


@@ -1,12 +1,13 @@
-import minerva as m
+import minerva
 import pprint
 
 pp = pprint.PrettyPrinter(indent=4)
 
-red = m.Redshift("hay", "s3://haystac-pmo-athena/",
-                 db="dev",
-                 cluster="redshift-cluster-1")
-query = red.query("select count(*) from myspectrum_schema.kitware")
+m = minerva.Minerva("hay")
+red = m.redshift("s3://haystac-pmo-athena/",
+                 db="dev",
+                 cluster="redshift-cluster-1")
+query = red.query("select count(*) from myspectrum_schema.kitware where agent = 4")
 
 data = query.results()
 pp.pprint(data.head(10))
 print(query.runtime)
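
As the commit message says, the Redshift path is still unparameterized, hence the literal `agent = 4` above. The Redshift Data API binds named placeholders (`:name`) via a list of name/value dicts rather than Athena's positional list of strings, so the Athena `params` plumbing would not transfer directly. A hedged sketch of what a future `Execute.run()` could pass (not part of this commit; values must be strings):

    resp = client.execute_statement(
        Sql="select count(*) from myspectrum_schema.kitware where agent = :agent",
        Database="dev",
        ClusterIdentifier="redshift-cluster-1",
        Parameters=[{"name": "agent", "value": "4"}],
    )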