Source code for dagster_snowflake.snowflake_io_manager
from typing import Sequence
from snowflake.connector import ProgrammingError
from dagster import Field, IOManagerDefinition, OutputContext, StringSource, io_manager
from .db_io_manager import DbClient, DbIOManager, DbTypeHandler, TablePartition, TableSlice
from .resources import SnowflakeConnection
# strftime pattern used to render partition time-window bounds as quoted
# datetime literals inside the generated SQL (see _time_window_where_clause).
SNOWFLAKE_DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"
def build_snowflake_io_manager(type_handlers: Sequence[DbTypeHandler]) -> IOManagerDefinition:
    """
    Builds an IO manager definition that reads inputs from and writes outputs to Snowflake.

    Args:
        type_handlers (Sequence[DbTypeHandler]): Each handler defines how to translate between
            slices of Snowflake tables and an in-memory type - e.g. a Pandas DataFrame.

    Returns:
        IOManagerDefinition

    Examples:

        .. code-block:: python

            from dagster import job
            from dagster_snowflake import build_snowflake_io_manager
            from dagster_snowflake_pandas import SnowflakePandasTypeHandler

            snowflake_io_manager = build_snowflake_io_manager([SnowflakePandasTypeHandler()])

            @job(resource_defs={'io_manager': snowflake_io_manager})
            def my_job():
                ...
    """

    @io_manager(
        config_schema={
            # Required connection settings.
            "database": StringSource,
            "account": StringSource,
            "user": StringSource,
            "password": StringSource,
            # Optional settings. Note that SnowflakeDbClient.delete_table_slice
            # replaces any configured "schema" with the schema of the table
            # slice being written.
            "warehouse": Field(StringSource, is_required=False),
            "schema": Field(StringSource, is_required=False),
        }
    )
    def snowflake_io_manager():
        # The handlers are captured from the enclosing call, so every IO
        # manager built from this definition shares the same handler sequence.
        return DbIOManager(type_handlers=type_handlers, db_client=SnowflakeDbClient())

    return snowflake_io_manager
class SnowflakeDbClient(DbClient):
    """
    Snowflake-specific ``DbClient`` handed to ``DbIOManager``.

    Provides the SQL needed to clear the target slice before an output is
    written and to select a slice when an input is loaded.
    """

    @staticmethod
    def delete_table_slice(context: OutputContext, table_slice: TableSlice) -> None:
        """
        Deletes the rows described by ``table_slice`` so the output can be rewritten.

        Args:
            context (OutputContext): Supplies the resource config used to open the
                connection and the logger passed to it.
            table_slice (TableSlice): The table (and optional partition) to clear.

        A ``ProgrammingError`` raised by the cleanup statement is treated as
        "table doesn't exist yet" and does not propagate; it is logged at debug
        level so that unexpected failures remain diagnosable.
        """
        # The connection must target the slice's schema, so any "schema" entry
        # from the resource config is dropped before the two are merged.
        resource_config = context.resource_config or {}
        no_schema_config = {
            key: value for key, value in resource_config.items() if key != "schema"
        }
        connection_config = dict(schema=table_slice.schema, **no_schema_config)

        with SnowflakeConnection(
            connection_config, context.log  # type: ignore
        ).get_connection() as con:
            try:
                con.execute_string(_get_cleanup_statement(table_slice))
            except ProgrammingError as err:
                # Best-effort cleanup: the table usually doesn't exist yet on the
                # first write. Record the error instead of discarding it silently.
                context.log.debug(
                    f"Ignoring error while clearing {table_slice.schema}.{table_slice.table} "
                    f"(table may not exist yet): {err}"
                )

    @staticmethod
    def get_select_statement(table_slice: TableSlice) -> str:
        """
        Returns the SELECT statement that reads ``table_slice``.

        All columns (``*``) are selected unless ``table_slice.columns`` is
        non-empty. When the slice has a partition, the statement is followed by
        a newline and the clause from ``_time_window_where_clause``.
        """
        col_str = ", ".join(table_slice.columns) if table_slice.columns else "*"
        select_stmt = (
            f"SELECT {col_str} FROM "
            f"{table_slice.database}.{table_slice.schema}.{table_slice.table}"
        )
        if not table_slice.partition:
            return select_stmt
        return select_stmt + "\n" + _time_window_where_clause(table_slice.partition)
def _get_cleanup_statement(table_slice: TableSlice) -> str:
    """
    Builds the DELETE statement that clears existing data for ``table_slice``
    before new output data is written.

    Without a partition the statement targets the whole table. With a
    partition, a newline and the clause from ``_time_window_where_clause`` are
    appended so only that partition's rows are removed.
    """
    qualified_table = f"{table_slice.database}.{table_slice.schema}.{table_slice.table}"
    delete_stmt = f"DELETE FROM {qualified_table}"

    # Guard clause: unpartitioned slices need no WHERE clause.
    if not table_slice.partition:
        return delete_stmt

    where_clause = _time_window_where_clause(table_slice.partition)
    return f"{delete_stmt}\n{where_clause}"
def _time_window_where_clause(table_partition: TablePartition) -> str:
    """
    Returns the SQL WHERE clause restricting a statement to one partition's time window.

    Args:
        table_partition (TablePartition): Provides ``time_window`` (a start/end
            datetime pair) and ``partition_expr`` (the SQL expression the bounds
            are compared against).

    Returns:
        str: ``WHERE <expr> >= '<start>' AND <expr> < '<end>'`` with both bounds
        rendered using ``SNOWFLAKE_DATETIME_FORMAT``.
    """
    start_dt, end_dt = table_partition.time_window
    start_dt_str = start_dt.strftime(SNOWFLAKE_DATETIME_FORMAT)
    end_dt_str = end_dt.strftime(SNOWFLAKE_DATETIME_FORMAT)
    partition_expr = table_partition.partition_expr
    # SQL BETWEEN is inclusive on both ends, which would also match rows stamped
    # exactly at the window end. Dagster time windows are documented as
    # start-inclusive / end-exclusive, so those rows belong to the next
    # partition; use a half-open range to avoid selecting or deleting them.
    return (
        f"WHERE {partition_expr} >= '{start_dt_str}' "
        f"AND {partition_expr} < '{end_dt_str}'"
    )