
A generic Python-based ETL pipeline solution for Databricks

Below is the code for a Databricks notebook source file that can be imported into your workspace. The file acts as a template for the ETL logic used to build tables in Databricks. Once the notebook is prepared, it can be set to run by a Databricks workflow job.
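If you want to create that workflow job programmatically rather than through the UI, the sketch below uses the Databricks Jobs API 2.1; the workspace URL, access token, notebook path and cluster id are placeholders you would need to replace with your own.

import requests

workspace_url = "https://YOUR_WORKSPACE_URL_HERE"
token = "YOUR_PERSONAL_ACCESS_TOKEN_HERE"

# Minimal job definition: one task that runs the ETL notebook on an existing cluster
payload = {
    "name": "generic_etl_pipeline",
    "tasks": [
        {
            "task_key": "run_etl_notebook",
            "notebook_task": {"notebook_path": "/Shared/generic_etl_pipeline"},
            "existing_cluster_id": "YOUR_CLUSTER_ID_HERE",
        }
    ],
    "schedule": {"quartz_cron_expression": "0 0 2 * * ?", "timezone_id": "UTC"},
}

response = requests.post(
    workspace_url + "/api/2.1/jobs/create",
    headers={"Authorization": "Bearer " + token},
    json=payload,
)
print(response.json())  # Returns the new job_id if the request succeeds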

The template is parameterized: the developer only needs to provide the destination database, the destination schema, the destination table name and the SQL logic.
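If you would rather pass the destination database and schema in from the job instead of hard-coding them in the Variables cell further down, one option (not part of the template itself) is to expose them as notebook widgets:

# Optional variation: read the destination parameters from notebook widgets
# so a Databricks job can supply them at run time
dbutils.widgets.text("dest_db_name", "YOUR_DEST_DATABASE_NAME_HERE")
dbutils.widgets.text("schema_name", "YOUR_SCHEMA_NAME_HERE")

dest_db_name = dbutils.widgets.get("dest_db_name")
schema_name = dbutils.widgets.get("schema_name")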

(Note: this simple example is a full load solution, not an incremental load solution. An incremental load can be achieved by writing sufficiently robust SQL that is specific to the use case.)
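Purely for illustration, a use-case-specific incremental pattern might merge only recently changed rows into the destination Delta table instead of overwriting it. In the sketch below the source table, the UPDATED_DATE watermark column and the one-day window are all hypothetical.

# Sketch of an incremental pattern using a Delta Lake MERGE
# (source_db.example_source and its UPDATED_DATE column are hypothetical)
spark.sql(
    """
    MERGE INTO YOUR_DEST_DATABASE_NAME_HERE.EXAMPLE_TABLE AS tgt
    USING (
        SELECT EXAMPLE_ID, EXAMPLE_LOCATION
        FROM source_db.example_source
        WHERE UPDATED_DATE >= date_sub(current_date(), 1)
    ) AS src
    ON tgt.EXAMPLE_ID = src.EXAMPLE_ID
    WHEN MATCHED THEN UPDATE SET tgt.EXAMPLE_LOCATION = src.EXAMPLE_LOCATION
    WHEN NOT MATCHED THEN INSERT (EXAMPLE_ID, EXAMPLE_LOCATION)
        VALUES (src.EXAMPLE_ID, src.EXAMPLE_LOCATION)
    """
)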

The SQL is provided as a variable and the table names are stored in a list, which gives a single pipeline the flexibility to build multiple database objects.
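For example (and purely as an illustration, since the template below uses a single query), the table list could be extended so that each entry carries its own SQL, with the loop at the bottom of the notebook picking the query up per table; the second table and both queries here are made up.

# Hypothetical multi-table variation: each list entry carries its own SQL,
# reusing build_object and the destination variables from the template below
table_list = [
    {
        "table_name": "EXAMPLE_TABLE",
        "sql": "SELECT 1 AS EXAMPLE_ID, 'TEXAS' AS EXAMPLE_LOCATION",
    },
    {
        "table_name": "ANOTHER_EXAMPLE_TABLE",
        "sql": "SELECT 2 AS EXAMPLE_ID, 'OHIO' AS EXAMPLE_LOCATION",
    },
]

for entry in table_list:
    sql_query = spark.sql(entry["sql"])
    build_object(dest_db_name, schema_name, entry["table_name"], pk_fields, sql_query)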

Another important feature of the code is that it compensates for the fact that Databricks does not natively enforce primary key constraints. A list of primary key fields can be provided, and if any of those key values are null or not distinct the code will raise an error.
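To see what that check catches, the toy dataframe below repeats a key combination and leaves one key value null, so both tests fail; the data is made up purely to demonstrate the behaviour.

# Toy demonstration of the primary key checks used in build_object
from pyspark.sql.functions import col, count, isnan, when

demo_df = spark.createDataFrame(
    [(1, "TEXAS"), (1, "TEXAS"), (2, None)],
    ["EXAMPLE_ID", "EXAMPLE_LOCATION"],
)
pk_fields = ["EXAMPLE_ID", "EXAMPLE_LOCATION"]

null_counts = demo_df.select(
    [count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in pk_fields]
).collect()[0]
print(sum(null_counts))  # 1 -> a key column contains a null
print(demo_df.count(), demo_df.dropDuplicates(pk_fields).count())  # 3 vs 2 -> keys not unique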

The code also assigns metadata fields to each record created, including the job run id as the ETL id, the created date and the updated date.
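As a design note, the template attaches those fields by cross joining a one-row metadata dataframe onto the query result. An equally workable alternative, sketched here with stand-in values and a stand-in query, is to add them as literal columns instead:

# Hypothetical alternative to the cross join: attach the metadata as literal columns.
# df stands in for the query result and etl_id for the value from get_etl_id().
from pyspark.sql.functions import current_timestamp, lit

etl_id = 1
df = spark.sql("SELECT 1 AS EXAMPLE_ID, 'TEXAS' AS EXAMPLE_LOCATION")
df = (
    df.withColumn("ETL_ID", lit(etl_id).cast("bigint"))
    .withColumn("CREATED_DATE", current_timestamp())
    .withColumn("UPDATED_DATE", current_timestamp())
)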

# Databricks notebook source
# MAGIC %md
# MAGIC https://tidbytez.com/<br />
# MAGIC This is an ETL notebook.<br />

# COMMAND ----------

# Libraries
import os
from pyspark.sql import functions as F
from pyspark.sql.types import StringType
from pyspark.sql.functions import lit, concat_ws, isnan, when, count, col
from datetime import datetime, timedelta

# COMMAND ----------

# Functions

# Generate ETL ID from the current job run id
def get_etl_id():
    try:
        run_id = (
            dbutils.notebook.entry_point.getDbutils()
            .notebook()
            .getContext()
            .currentRunId()
            .toString()
        )
        if run_id.isdigit():
            return int(run_id)
        else:
            # Interactive runs have no numeric run id, so fall back to 1
            return int(1)
    except Exception:
        # Fall back to 1 rather than returning None, which would break
        # the metadata dataframe created below
        print("Could not return an etl_id number")
        return int(1)


# Build database object
def build_object(dest_db_name, schema_name, table_name, pk_fields, sql_query):

    # Destination database and table
    table_location = dest_db_name + "." + table_name
    # External table file location
    file_location = "/mnt/" + schema_name + "/" + table_name

    # sql_query is already a DataFrame returned by spark.sql()
    df = sql_query

    # Count nulls (or NaNs) across all Primary Key fields
    pk_null_counts = df.select(
        [count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in pk_fields]
    ).collect()[0]
    cnt_pk_nulls = sum(pk_null_counts)
    # Dataframe record count
    cnt_rows = df.count()
    # Primary Key distinct count
    cnt_dist = df.dropDuplicates(pk_fields).count()
    # Error message
    message = ""

    # Join metadata to dataframe
    global meta_df
    meta = meta_df
    df = df.withColumn("key", lit(1))

    # Cross join on the constant key so every record picks up the metadata columns
    df = df.join(meta, df.key == meta.key, "inner").drop(df.key).drop(meta.key)

    # Write dataframe to table
    if cnt_pk_nulls == 0:
        if cnt_rows == cnt_dist:
            df.write.mode("overwrite").format("delta").option(
                "mergeSchema", "false"
            ).option("path", file_location).saveAsTable(table_location)
        else:
            message = "Primary Key is not unique"
    else:
        message = "Primary Key contains nulls"

    if message != "":
        raise Exception(message)


# COMMAND ----------

# Variables

# Destinations

# Schema name (also used as the mount folder for the external table files)
schema_name = "YOUR_SCHEMA_NAME_HERE"

# Destination database
dest_db_name = "YOUR_DEST_DATABASE_NAME_HERE"

# PK fields
pk_fields = ["EXAMPLE_ID", "EXAMPLE_LOCATION"]

# Metadata
etl_id = get_etl_id()
t = datetime.utcnow()

# Create a one-row metadata dataframe (joined onto every record in build_object)
data = [(1, etl_id, t, t)]
columns = ["key", "ETL_ID", "CREATED_DATE", "UPDATED_DATE"]
meta_df = spark.createDataFrame(data, columns)
# Job run ids can be large, so keep the ETL id as a bigint rather than an int
meta_df = meta_df.withColumn("ETL_ID", meta_df["ETL_ID"].cast("bigint"))

# COMMAND ----------

# Table name variable list
table_list = [
    {"table_name": "EXAMPLE_TABLE"}
]

# COMMAND ----------

# Iterate through the table variable list
for entry in table_list:

    table_name = entry.get("table_name")

    # SQL query (returned as a DataFrame)
    sql_query = spark.sql(
        f"""
        SELECT 1 AS EXAMPLE_ID,
        'TEXAS' AS EXAMPLE_LOCATION
        """
    )
    build_object(dest_db_name, schema_name, table_name, pk_fields, sql_query)
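As an optional check after a run, you can read the table back and confirm that the metadata columns were populated; the database and table names here match the placeholders used above.

# Inspect the table the pipeline just built, including its metadata columns
result = spark.table(dest_db_name + ".EXAMPLE_TABLE")
result.select(
    "EXAMPLE_ID", "EXAMPLE_LOCATION", "ETL_ID", "CREATED_DATE", "UPDATED_DATE"
).show()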