forked from bellwether/minerva

commit 5eb8471081 (parent e854a93e60)

    better redshift support, working on dask stuff

5 changed files with 86 additions and 175 deletions
@@ -113,7 +113,7 @@ class Machine:
         self.description = resp['Reservations'][0]['Instances'][0]
         self.public_ip = self.description['PublicIpAddress']

-        print(f"\t{self.name} ({self.info['InstanceId']}) => {self.public_ip} ({self.private_ip})")
+        print(f"\t{self.name} ({self.info['InstanceId']}\t- {self.instance_type}) => {self.public_ip} ({self.private_ip})")

         ip = self.public_ip or self.private_ip
         self.ssh = Connection(ip,
@@ -12,12 +12,19 @@ from minerva import parallel_map
 pp = pprint.PrettyPrinter(indent=4)

 class Redshift:
-    def __init__(self, handler, output, db=None, cluster=None):
-        self.handler = handler
-        self.client = handler.session.client("redshift-data")
-        self.output = output
-        self.database = db
-        self.cluster = cluster
+    def __init__(self, handler, output, db=None, cluster=None, workgroup=None, secret=None):
+        self.handler = handler
+        self.client = handler.session.client("redshift-data")
+        self.output = output
+        self.database = db
+        self.cluster = cluster
+        self.workgroup = workgroup
+        self.secret = secret
+        self.rpus = None    # default, so finish() can test it in cluster mode
+
+        if workgroup:
+            cli = handler.session.client("redshift-serverless")
+            wg = cli.get_workgroup(workgroupName=workgroup)
+            self.rpus = wg['workgroup']['baseCapacity']

     def query(self, sql):
         q = Query(self, sql)
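The new constructor keeps the provisioned-cluster path but also accepts a serverless workgroup plus the Secrets Manager ARN the Data API uses to authenticate against it; when a workgroup is given, its base RPU capacity is looked up so queries can be costed later. A minimal sketch of the two modes, with handler standing in for minerva's session wrapper and every name and ARN purely illustrative:

    # Provisioned cluster: the Data API call will use ClusterIdentifier.
    rs = Redshift(handler, "s3://bucket/unload/", db="analytics",
                  cluster="etl-cluster")

    # Serverless: needs the workgroup name plus a Secrets Manager ARN.
    rs = Redshift(handler, "s3://bucket/unload/", db="analytics",
                  workgroup="etl-wg",
                  secret="arn:aws:secretsmanager:us-east-1:123456789012:secret:redshift-creds")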
@@ -41,31 +48,52 @@ class Execute:
         self.client = redshift.client
         self.sql = sql
         self.info_cache = None
         self.ds = None
         self.files = None
+        self.temps = []

     def query(self):
         return self.sql


     def run(self):
-        resp = self.client.execute_statement(Sql=self.query(),
-                                             Database=self.redshift.database,
-                                             ClusterIdentifier=self.redshift.cluster)
+        if self.redshift.cluster:
+            resp = self.client.execute_statement(Sql=self.query(),
+                                                 Database=self.redshift.database,
+                                                 ClusterIdentifier=self.redshift.cluster)
+        else:
+            resp = self.client.execute_statement(Sql=self.query(),
+                                                 Database=self.redshift.database,
+                                                 SecretArn=self.redshift.secret,
+                                                 WorkgroupName=self.redshift.workgroup)

         self.query_id = resp['Id']
         return resp


     def status(self):
         return self.info()['Status']


     def info(self):
         res = self.client.describe_statement(Id=self.query_id)
         self.info_cache = res
         return self.info_cache


     # Block until the SQL has finished running
     def finish(self):
         stat = self.status()
         while stat in ['SUBMITTED', 'PICKED', 'STARTED']:
             time.sleep(5)
             stat = self.status()

+        pp.pprint(self.info_cache)
+        self.runtime = self.info_cache['UpdatedAt'] - self.info_cache['CreatedAt']
+
+        if self.redshift.rpus:
+            self.cost = 0.36 * self.redshift.rpus * self.runtime.seconds / 3600.0  # $0.36 / RPU-hour

         return stat  # finalized state
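To make the cost line concrete: an 8-RPU workgroup whose statement shows a 90-second UpdatedAt minus CreatedAt span is estimated at 0.36 * 8 * 90 / 3600, roughly $0.072. A sketch of driving the new run() branch, assuming Execute is constructed as Execute(redshift, sql) and rs is one of the hypothetical instances above:

    ex = Execute(rs, "select count(*) from events")  # hypothetical table
    ex.run()                     # routes to the cluster or workgroup call
    final = ex.finish()          # 'FINISHED', 'FAILED', or 'ABORTED'
    print(final, ex.runtime)     # runtime is a datetime.timedelta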
@@ -75,12 +103,22 @@ class Query(Execute):
     def query(self):
         self.out = os.path.join(self.redshift.output,
                                 str(random.random()))
-        query = f"unload ({repr(self.sql)}) to {repr(self.out)} " + \
-                f"iam_role default " + \
-                f"format as {self.DATA_STYLE} " + \
-                f"manifest"
+        #query = f"unload ({repr(self.sql)}) to {repr(self.out)} " + \
+        #        f"iam_role default " + \
+        #        f"format as {self.DATA_STYLE} " + \
+        #        f"manifest"
+
+        query = f"""
+        create temp table temp_data as {self.sql};
+        unload ('select * from temp_data') to {repr(self.out)}
+            iam_role default
+            format as {self.DATA_STYLE}
+            manifest;
+        drop table temp_data;
+        """
        return query


     def manifest_files(self):
         status = self.finish()
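The rewritten query() stages the caller's SQL into a temp table and unloads from that, so the UNLOAD body stays a constant string and the arbitrary SQL no longer has to survive repr()-quoting. As a worked rendering, with self.sql as the hypothetical 'select id, ts from events', DATA_STYLE assumed to expand to parquet, and an illustrative random output path, the generated statement would be approximately:

    create temp table temp_data as select id, ts from events;
    unload ('select * from temp_data') to 's3://bucket/unload/0.4212395'
        iam_role default
        format as parquet
        manifest;
    drop table temp_data;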
@@ -102,9 +140,29 @@ class Query(Execute):
         else:
             return status  # canceled or error


     def results(self):
-        #local = [self.handler.s3.download(f) for f in self.manifest_files()]
-        local = parallel_map(self.handler.s3.download, self.manifest_files())
-        self.ds = pa.dataset.dataset(local)
+        self.temps = [self.handler.s3.download(f) for f in self.manifest_files()]
+        #local = parallel_map(self.handler.s3.download, self.manifest_files())
+        self.ds = pa.dataset.dataset(self.temps)
         return self.ds


+    # Return scalar results
+    # Abstracts away a bunch of keystrokes
+    def scalar(self):
+        return self.results().head(1)[0][0].as_py()
+
+
+    def __enter__(self):
+        return self
+
+
+    def __exit__(self, exception_type, exception_value, exception_traceback):
+        self.close()
+
+
+    def close(self):
+        if self.temps:
+            for file in self.temps:
+                os.remove(file)
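The new scalar(), context-manager hooks, and close() give single-value queries a tidy lifecycle: results() now records every downloaded file in self.temps, and close() deletes them when the with-block exits. A sketch of the intended flow, assuming Redshift.query() returns the Query it builds:

    with rs.query("select count(*) from events") as q:
        n = q.scalar()   # head(1) -> first column -> first value -> Python object
    # __exit__ has called close(), removing the files listed in self.temps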