diff --git a/minerva/athena.py b/minerva/athena.py
index 0a316e4..0062341 100644
--- a/minerva/athena.py
+++ b/minerva/athena.py
@@ -19,6 +19,7 @@ class Athena:
     def __init__(self, handler, output=None):
         self.handler = handler
         self.client = handler.session.client("athena")
+        self.glue = handler.session.client('glue')
         self.output = output
 
 
@@ -40,20 +41,38 @@ class Athena:
         p = Parallelize(self, *args, **kwargs)
         return p
 
-    def delete_table(self, table, join=True):
-        e = Execute(self, f"drop table {table}")
+    def delete_table(self, db_name, table, join=True):
+        table = table.split(".")[-1]
+        e = Execute(self, f"drop table {db_name}.{table}")
         e.run()
 
         if join:
             e.finish()
 
+        resp = self.glue.batch_delete_table(DatabaseName = db_name,
+                                            TablesToDelete = [table])
+
         s3_uri = os.path.join(self.output, table, "")
+        #print(f"deleting {s3_uri}")
         self.handler.s3.rm(s3_uri)
 
         return e
 
+    def delete_tables(self, db_name, tables):
+        e = Execute(self, f"drop table {', '.join(tables)}")
+        e.run()
+        e.finish()
+
+        self.glue.batch_delete_table(DatabaseName = db_name,
+                                     TablesToDelete = tables)
+
+        for table in tables:
+            s3_uri = os.path.join(self.output, table, "")
+            #print(f"deleting {s3_uri}")
+            self.handler.s3.rm(s3_uri)
+
     def cancel(self, query_id):
-        return self.client.stop_query_execution(QueryExecutionId=query_id)
+        return self.client.stop_query_execution(QueryExecutionId = query_id)
 
 
 class Execute:
@@ -227,6 +246,7 @@ class Parallelize:
     def __init__(self, athena, dest=None, data=[], n=1):
         self.athena = athena
         self.dest = dest
+        self.db = dest and dest.split('.')[0]
         self.n = n
         self.tables = []
         self.queries = []
@@ -235,7 +255,7 @@ class Parallelize:
 
         if type(data) == type(1):
             self.data = list(range(data))
-        elif type(data) == type([]) or data.__iter__:
+        elif type(data) == type([]) or hasattr(data, '__iter__'):
             self.data = list(data)
         else:
             raise Exception(f"Passed in {type(data)}, expected list-like or integer")
@@ -262,19 +282,23 @@ class Parallelize:
             self.clear_temp_tables()
             raise StopIteration
 
-        # temp table name, in case it's needed
-        tmp = "temp_" + str(round(random.random() * 10_000_000))
-        self.tables.append(tmp)
-
         obj = self.groups[self.current]
         self.current += 1
 
-        return tmp, obj
+        if self.dest:
+            # temp table name, in case it's needed
+            tmp = "temp_" + str(round(random.random() * 10_000_000))
+            self.tables.append(tmp)
+
+            return tmp, obj
+        else:
+            return obj
 
 
     # runs the SQL and removes the S3 data
     def clear_temp_tables(self):
-        qs = [self.athena.delete_table(table, join=False) for table in self.tables]
+        #pp.pprint(self.tables)
+        qs = [self.athena.delete_table(self.db, table, join=False) for table in self.tables]
         for q in qs:
             q.finish()
 
@@ -286,10 +310,13 @@ class Parallelize:
         self.cost = sum([q.cost for q in self.queries])
         self.runtime = max([q.runtime for q in self.queries])
 
+    def files(self):
+        return [f for q in self.queries for f in q.results().files]
+
     def results(self):
         self.finish()
 
-        return pa.dataset.dataset([f for q in self.queries for f in q.results().files])
+        return pa.dataset.dataset(self.files())
 
 
     def union_tables(self, dest):
diff --git a/pyproject.toml b/pyproject.toml
index 5061d4b..72eceb8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "minerva"
-version = "0.7.6"
+version = "0.7.7"
 description = "Easier access to AWS Athena and Redshift"
 authors = [
     "Ari Brown ",