import boto3
import os
import random
import time
import pyarrow as pa
import pyarrow.dataset
import pprint
import datetime
import json
from minerva import parallel_map

pp = pprint.PrettyPrinter(indent=4)


class Redshift:
    """Connection settings for running SQL through the Redshift Data API.

    Either ``cluster`` (provisioned) or ``workgroup`` (serverless) selects
    the target; ``run()`` prefers ``cluster`` when both are given.

    Args:
        handler: Project object exposing ``session.client(...)`` (used to
            build the AWS clients) and ``s3.download(...)`` (used by
            ``Query``).
        output: Base location under which ``Query`` writes its UNLOAD
            results (``<output>/results/<random>.``) -- presumably an S3 URI
            prefix; confirm against callers.
        db: Database name passed to ``execute_statement``.
        cluster: Provisioned cluster identifier.
        workgroup: Serverless workgroup name.
        secret: Secrets Manager ARN used for authentication, if any.
    """

    def __init__(self, handler, output, db=None, cluster=None,
                 workgroup=None, secret=None):
        self.handler = handler
        self.client = handler.session.client("redshift-data")
        self.output = output
        self.database = db
        self.cluster = cluster
        self.workgroup = workgroup
        self.secret = secret

        # Always defined so Execute.finish() can test it; it stays None for
        # provisioned clusters, where no RPU-based cost is computed.
        self.rpus = None

        if workgroup:
            cli = handler.session.client("redshift-serverless")
            wg = cli.get_workgroup(workgroupName=workgroup)
            self.rpus = wg['workgroup']['baseCapacity']

    def query(self, sql):
        """Submit ``sql`` as a result-returning ``Query`` and return it."""
        q = Query(self, sql)
        q.run()
        return q

    def execute(self, sql):
        """Submit ``sql`` as a fire-and-poll ``Execute`` and return it."""
        e = Execute(self, sql)
        e.run()
        return e


class Execute:
    """
    Execute will not return results, but will execute the SQL and return the
    final state. Execute is meant to be used for DDL/DML statements such as
    CREATE DATABASE/TABLE.

    Attributes set over the statement's lifetime:
        query_id: Statement id returned by ``run()`` (None before that).
        runtime: ``UpdatedAt - CreatedAt`` timedelta, set by ``finish()``.
        cost: Estimated serverless cost in dollars, set by ``finish()`` when
            the workgroup's RPU capacity is known; otherwise None.
    """

    # States in which describe_statement will no longer change.
    FINAL_STATES = ('FINISHED', 'ABORTED', 'FAILED')
    # States in which the statement is still queued or running.
    PENDING_STATES = ('SUBMITTED', 'PICKED', 'STARTED')

    def __init__(self, redshift, sql):
        self.redshift = redshift
        self.handler = redshift.handler
        self.client = redshift.client
        self.sql = sql
        self.query_id = None
        self.info_cache = None
        self.status_cache = None
        self.runtime = None
        self.cost = None
        self.ds = None
        self.files = None
        self.temps = []

    def query(self):
        """Return the SQL text that ``run()`` submits."""
        return self.sql

    def run(self):
        """Submit the statement and remember its id.

        Returns:
            The raw ``execute_statement`` response.
        """
        params = {"Sql": self.query()}

        # Passing Database=None fails boto3 parameter validation, so only
        # send it when one was configured.
        if self.redshift.database is not None:
            params["Database"] = self.redshift.database

        if self.redshift.cluster:
            params["ClusterIdentifier"] = self.redshift.cluster
        else:
            params["WorkgroupName"] = self.redshift.workgroup

        # The secret applies to both provisioned and serverless targets.
        if self.redshift.secret:
            params["SecretArn"] = self.redshift.secret

        resp = self.client.execute_statement(**params)
        self.query_id = resp['Id']
        return resp

    def status(self):
        """Return the statement's current ``Status`` string."""
        return self.info()['Status']

    def info(self):
        """Return the ``describe_statement`` response for this statement.

        Once a final state has been observed the cached response is returned
        without another API call.

        Raises:
            RuntimeError: if ``run()`` has not been called yet.
        """
        if self.status_cache in self.FINAL_STATES:
            return self.info_cache

        if self.query_id is None:
            raise RuntimeError("statement has not been submitted; call run() first")

        res = self.client.describe_statement(Id=self.query_id)
        self.info_cache = res
        self.status_cache = res['Status']
        return self.info_cache

    # Block until the SQL has finished running
    def finish(self):
        """Poll every 5 seconds until the statement leaves the pending states.

        Also records ``self.runtime`` and, when the RPU capacity is known,
        ``self.cost``.

        Returns:
            The finalized state string.
        """
        stat = self.status()
        while stat in self.PENDING_STATES:
            time.sleep(5)
            stat = self.status()

        self.runtime = self.info_cache['UpdatedAt'] - self.info_cache['CreatedAt']

        if self.redshift.rpus:
            # $0.36 / RPU-hour (hard-coded rate).
            # total_seconds() rather than .seconds: the latter drops whole
            # days and would under-count statements running past 24 hours.
            self.cost = 0.36 * self.redshift.rpus * self.runtime.total_seconds() / 3600.0

        return stat  # finalized state


class Query(Execute):
    """A statement whose results are UNLOADed and read back as a dataset.

    Usable as a context manager; leaving the ``with`` block removes the
    locally downloaded result files.
    """

    DATA_STYLE = 'parquet'

    def query(self):
        """Build the SQL that stages the results and UNLOADs them.

        Side effect: picks a fresh random output prefix and stores it in
        ``self.out``; ``manifest_files()`` relies on that value, so calling
        this again after ``run()`` would point it at the wrong location.
        """
        self.out = os.path.join(self.redshift.output,
                                "results",
                                str(random.random()) + ".")

        # The user's SQL is staged into a temp table and that table is
        # unloaded, instead of embedding the SQL text inside unload ('...').
        # NOTE(review): presumably to avoid quoting the inner SQL -- confirm.
        query = f"""
create temp table temp_data as {self.sql};
unload ('select * from temp_data')
to {repr(self.out)}
iam_role default
format as {self.DATA_STYLE}
manifest;
drop table temp_data;
"""

        print(query)
        return query

    def manifest_files(self):
        """Wait for completion and return the result file URLs.

        Returns:
            A list of URLs from the UNLOAD manifest (empty when the statement
            reported zero result rows), or the final status string when the
            statement did not finish successfully.
        """
        # `is not None` so a cached empty list also counts as computed.
        if self.files is not None:
            return self.files

        status = self.finish()
        if status != "FINISHED":
            return status  # canceled or error

        if self.info_cache['ResultRows'] != 0:
            # Because we're using `UNLOAD`, we get a manifest of the files
            # that make up our data.
            manif = self.out + "manifest"
            tmp = self.handler.s3.download(manif)
            try:
                with open(tmp, "r") as f:
                    js = json.load(f)
            finally:
                # The manifest copy is only needed for parsing. Assumes
                # download() hands back a caller-owned local file, as
                # close() already assumes for the result files.
                os.remove(tmp)

            # Filter empty strings
            self.files = [e['url'].strip() for e in js['entries'] if e['url'].strip()]
        else:
            # no results returned, so no manifest file was created
            self.files = []

        return self.files

    def results(self):
        """Download the result files and return them as a pyarrow dataset.

        The dataset is cached until ``close()`` so repeated calls do not
        download the files again.

        Raises:
            Exception: if the statement did not finish successfully.
        """
        if self.ds is not None:
            return self.ds

        files = self.manifest_files()

        # if it's not a list, then we've failed
        if not isinstance(files, list):
            raise Exception(f"Query has status {self.status()} did not complete and thus has no results")

        # NOTE: downloads are sequential; parallel_map(self.handler.s3.download,
        # files) was the alternative considered in the original code.
        self.temps = [self.handler.s3.download(f) for f in files]
        self.ds = pa.dataset.dataset(self.temps)
        return self.ds

    # Return scalar results
    # Abstracts away a bunch of keystrokes
    def scalar(self):
        """Return the first column of the first row as a Python value."""
        return self.results().head(1)[0][0].as_py()

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        self.close()

    def close(self):
        """Delete downloaded result files; safe to call more than once."""
        temps, self.temps = self.temps, []
        # The cached dataset points at the files being removed.
        self.ds = None

        for path in temps:
            try:
                os.remove(path)
            except FileNotFoundError:
                # Already gone (e.g. removed externally); nothing to clean up.
                pass