Parameterized queries for Athena, but not for Redshift.

This commit is contained in:
Ari Brown 2023-08-02 16:51:02 -04:00
parent c4e0b71a98
commit 52f44a79fe
6 changed files with 82 additions and 53 deletions

View file

@@ -12,9 +12,9 @@ from minerva import parallel_map
pp = pprint.PrettyPrinter(indent=4)
class Redshift:
def __init__(self, profile, output, db=None, cluster=None):
self.session = boto3.session.Session(profile_name=profile)
self.redshift = self.session.client("redshift-data")
def __init__(self, handler, output, db=None, cluster=None):
self.handler = handler
self.client = handler.session.client("redshift-data")
self.output = output
self.database = db
self.cluster = cluster
@@ -29,22 +29,16 @@ class Redshift:
e.run()
return e
def download(self, s3):
bucket = s3.split("/")[2]
file = os.path.join(*s3.split("/")[3:])
tmp = f"/tmp/{random.random()}.bin"
self.session.client('s3').download_file(bucket, file, tmp)
return tmp
class Execute:
"""
Execute will not return results, but will execute the SQL and return the final state.
Execute is meant to be used for DML statements such as CREATE DATABASE/TABLE
"""
def __init__(self, handler, sql):
self.handler = handler
self.redshift = handler.redshift
def __init__(self, redshift, sql):
self.redshift = redshift
self.handler = redshift.handler
self.client = redshift.client
self.sql = sql
self.info_cache = None
@@ -52,9 +46,9 @@ class Execute:
return self.sql
def run(self):
resp = self.redshift.execute_statement(Sql=self.query(),
Database=self.handler.database,
ClusterIdentifier=self.handler.cluster)
resp = self.client.execute_statement(Sql=self.query(),
Database=self.redshift.database,
ClusterIdentifier=self.redshift.cluster)
self.query_id = resp['Id']
return resp
@@ -62,7 +56,7 @@ class Execute:
return self.info()['Status']
def info(self):
res = self.redshift.describe_statement(Id=self.query_id)
res = self.client.describe_statement(Id=self.query_id)
self.info_cache = res
return self.info_cache
@@ -79,7 +73,7 @@ class Query(Execute):
DATA_STYLE = 'parquet'
def query(self):
self.out = os.path.join(self.handler.output,
self.out = os.path.join(self.redshift.output,
str(random.random()))
query = f"unload ({repr(self.sql)}) to {repr(self.out)} " + \
f"iam_role default " + \
@@ -97,7 +91,7 @@ class Query(Execute):
# Because we're using `UNLOAD`, we get a manifest of the files
# that make up our data.
manif = self.out + "manifest"
tmp = self.handler.download(manif)
tmp = self.handler.s3.download(manif)
with open(tmp, "r") as f:
js = json.load(f)
@@ -110,8 +104,8 @@ class Query(Execute):
# TODO parallelize this
def results(self):
#local = [self.handler.download(f) for f in self.manifest_files()]
local = parallel_map(self.handler.download, self.manifest_files())
#local = [self.handler.s3.download(f) for f in self.manifest_files()]
local = parallel_map(self.handler.s3.download, self.manifest_files())
self.ds = pa.dataset.dataset(local)
return self.ds