SQLAlchemy - performing a bulk upsert (if exists, update, else insert) in postgresql


Solution 1

There is an upsert-esque operation in SQLAlchemy:

db.session.merge()

After I found this command, I was able to perform upserts, but it is worth mentioning that this operation is slow for a bulk "upsert".
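
For a single object, usage is straightforward (a minimal sketch, assuming a Flask-SQLAlchemy posts model with an integer id primary key and a title column, mirroring the example below):

# merge() loads any existing row with the same primary key and reconciles the
# given state with it; the actual INSERT or UPDATE happens at flush/commit time
post = posts(id=1, title="new title")
db.session.merge(post)
db.session.commit()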

The alternative is to get a list of the primary keys you would like to upsert, and query the database for any matching ids:

# Imagine that post1, post5, and post1000 are posts objects with ids 1, 5 and 1000 respectively
# The goal is to "upsert" these posts.
# we initialize a dict which maps id to the post object

my_new_posts = {1: post1, 5: post5, 1000: post1000} 

for each in posts.query.filter(posts.id.in_(my_new_posts.keys())).all():
    # Only merge those posts which already exist in the database
    db.session.merge(my_new_posts.pop(each.id))

# Only add those posts which did not exist in the database 
db.session.add_all(my_new_posts.values())

# Now we commit our modifications (merges) and inserts (adds) to the database!
db.session.commit()

Solution 2

You can leverage the on_conflict_do_update() method of the PostgreSQL-specific insert() construct. A simple example would be the following:

from sqlalchemy import Column, Integer, Unicode
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

class Post(Base):
    """
    A simple class for demonstration
    """
    __tablename__ = "post"

    id = Column(Integer, primary_key=True)
    title = Column(Unicode)

# Prepare all the values that should be "upserted" to the DB
values = [
    {"id": 1, "title": "mytitle 1"},
    {"id": 2, "title": "mytitle 2"},
    {"id": 3, "title": "mytitle 3"},
    {"id": 4, "title": "mytitle 4"},
]

stmt = insert(Post).values(values)
stmt = stmt.on_conflict_do_update(
    # Use the primary-key constraint name (the question's error message showed "posts_pkey"; for this Post model it is "post_pkey")
    constraint="post_pkey",

    # The columns that should be updated on conflict
    set_={
        "title": stmt.excluded.title
    }
)
session.execute(stmt)
# Commit so the upsert is actually persisted
session.commit()

See the Postgres docs for more details about ON CONFLICT DO UPDATE.

See the SQLAlchemy docs for more details about on_conflict_do_update.
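
If you would rather not hard-code the constraint name, on_conflict_do_update() also accepts an index_elements argument, so the conflict target can be named by column instead. A small sketch of the same statement, reusing the Post model, values, and session from above:

stmt = insert(Post).values(values)
stmt = stmt.on_conflict_do_update(
    # The conflict target is given as the primary-key column itself,
    # instead of the constraint name
    index_elements=[Post.id],
    set_={"title": stmt.excluded.title},
)
session.execute(stmt)
session.commit()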

Side-Note on duplicated column names

The above code uses the column names as dict keys, both in the values list and in the argument to set_. If a column name is changed in the class definition, it needs to be changed everywhere, or the code will break. This can be avoided by accessing the column definitions, which makes the code a bit uglier but more robust:

coldefs = Post.__table__.c

values = [
    {coldefs.id.name: 1, coldefs.title.name: "mytitle 1"},
    ...
]

stmt = stmt.on_conflict_do_update(
    ...
    set_={
        coldefs.title.name: stmt.excluded.title
        ...
    }
)
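
Putting those pieces together, a complete sketch of the more robust variant (reusing the Post model, session, and post_pkey constraint from above) might look like this:

coldefs = Post.__table__.c

# Build both the values and the update mapping from the column definitions,
# so renaming a column in the model only has to happen in one place
values = [
    {coldefs.id.name: i, coldefs.title.name: f"mytitle {i}"} for i in range(1, 5)
]

stmt = insert(Post).values(values)
stmt = stmt.on_conflict_do_update(
    constraint="post_pkey",
    set_={coldefs.title.name: stmt.excluded.title},
)
session.execute(stmt)
session.commit()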

Solution 3

An alternative approach using the compilation extension (https://docs.sqlalchemy.org/en/13/core/compiler.html):

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Insert

@compiles(Insert)
def compile_upsert(insert_stmt, compiler, **kwargs):
    """
    converts every SQL insert to an upsert  i.e;
    INSERT INTO test (foo, bar) VALUES (1, 'a')
    becomes:
    INSERT INTO test (foo, bar) VALUES (1, 'a') ON CONFLICT(foo) DO UPDATE SET (bar = EXCLUDED.bar)
    (assuming foo is a primary key)
    :param insert_stmt: Original insert statement
    :param compiler: SQL Compiler
    :param kwargs: optional arguments
    :return: upsert statement
    """
    pk = insert_stmt.table.primary_key
    insert = compiler.visit_insert(insert_stmt, **kwargs)
    ondup = f'ON CONFLICT ({",".join(c.name for c in pk)}) DO UPDATE SET'
    updates = ', '.join(f"{c.name}=EXCLUDED.{c.name}" for c in insert_stmt.table.columns)
    upsert = ' '.join((insert, ondup, updates))
    return upsert

This should ensure that all INSERT statements behave as upserts. The implementation targets the PostgreSQL dialect, but it should be fairly easy to adapt for the MySQL dialect.
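
For illustration, here is a rough, untested sketch of such a MySQL adaptation: it registers the hook only for the mysql dialect and emits MySQL's ON DUPLICATE KEY UPDATE for the non-key columns (the compile_upsert_mysql name is made up for this example, and the table is assumed to have at least one non-key column):

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Insert

@compiles(Insert, "mysql")
def compile_upsert_mysql(insert_stmt, compiler, **kwargs):
    """
    Same idea as above, but registered only for the mysql dialect and
    emitting ON DUPLICATE KEY UPDATE for the non-primary-key columns.
    """
    insert = compiler.visit_insert(insert_stmt, **kwargs)
    updates = ', '.join(
        f"{c.name}=VALUES({c.name})"
        for c in insert_stmt.table.columns
        if not c.primary_key
    )
    return f"{insert} ON DUPLICATE KEY UPDATE {updates}"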

Solution 4

I started looking at this and I think I've found a pretty efficient way to do upserts in SQLAlchemy, using a mix of bulk_insert_mappings and bulk_update_mappings instead of merge.

import time
import sqlite3

from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from contextlib import contextmanager


engine = None
Session = sessionmaker()
Base = declarative_base()


def create_new_database(db_name="sqlite:///bulk_upsert_sqlalchemy.db"):
    global engine
    engine = create_engine(db_name, echo=False)
    local_session = scoped_session(Session)
    local_session.remove()
    local_session.configure(bind=engine, autoflush=False, expire_on_commit=False)
    Base.metadata.drop_all(engine)
    Base.metadata.create_all(engine)


@contextmanager
def db_session():
    local_session = scoped_session(Session)
    session = local_session()

    session.expire_on_commit = False

    try:
        yield session
    except BaseException:
        session.rollback()
        raise
    finally:
        session.close()


class Customer(Base):
    __tablename__ = "customer"
    id = Column(Integer, primary_key=True)
    name = Column(String(255))


def bulk_upsert_mappings(customers):

    entries_to_update = []
    entries_to_put = []
    with db_session() as sess:
        t0 = time.time()

        # Find all customers that need to be updated and build mappings
        for each in (
            sess.query(Customer.id).filter(Customer.id.in_(customers.keys())).all()
        ):
            customer = customers.pop(each.id)
            entries_to_update.append({"id": customer["id"], "name": customer["name"]})

        # Bulk mappings for everything that needs to be inserted
        for customer in customers.values():
            entries_to_put.append({"id": customer["id"], "name": customer["name"]})

        sess.bulk_insert_mappings(Customer, entries_to_put)
        sess.bulk_update_mappings(Customer, entries_to_update)
        sess.commit()

    print(
        "Total time for upsert with MAPPING update "
        + str(len(entries_to_put) + len(entries_to_update))
        + " records "
        + str(time.time() - t0)
        + " sec"
        + " inserted : "
        + str(len(entries_to_put))
        + " - updated : "
        + str(len(entries_to_update))
    )


def bulk_upsert_merge(customers):

    entries_to_update = 0
    entries_to_put = []
    with db_session() as sess:
        t0 = time.time()

        # Find all customers that need to be updated and merge
        for each in (
            sess.query(Customer.id).filter(Customer.id.in_(customers.keys())).all()
        ):
            values = customers.pop(each.id)
            sess.merge(Customer(id=values["id"], name=values["name"]))
            entries_to_update += 1

        # Bulk mappings for everything that needs to be inserted
        for customer in customers.values():
            entries_to_put.append({"id": customer["id"], "name": customer["name"]})

        sess.bulk_insert_mappings(Customer, entries_to_put)
        sess.commit()

    print(
        "Total time for upsert with MERGE update "
        + str(len(entries_to_put) + entries_to_update)
        + " records "
        + str(time.time() - t0)
        + " sec"
        + " inserted : "
        + str(len(entries_to_put))
        + " - updated : "
        + str(entries_to_update)
    )


if __name__ == "__main__":

    batch_size = 10000

    # Only inserts
    customers_insert = {
        i: {"id": i, "name": "customer_" + str(i)} for i in range(batch_size)
    }

    # 50/50 inserts update
    customers_upsert = {
        i: {"id": i, "name": "customer_2_" + str(i)}
        for i in range(int(batch_size / 2), batch_size + int(batch_size / 2))
    }

    create_new_database()
    bulk_upsert_mappings(customers_insert.copy())
    bulk_upsert_mappings(customers_upsert.copy())
    bulk_upsert_mappings(customers_insert.copy())

    create_new_database()
    bulk_upsert_merge(customers_insert.copy())
    bulk_upsert_merge(customers_upsert.copy())
    bulk_upsert_merge(customers_insert.copy())

The results for the benchmark:

Total time for upsert with MAPPING: 0.17138004302978516 sec inserted : 10000 - updated : 0
Total time for upsert with MAPPING: 0.22074174880981445 sec inserted : 5000 - updated : 5000
Total time for upsert with MAPPING: 0.22307634353637695 sec inserted : 0 - updated : 10000
Total time for upsert with MERGE: 0.1724097728729248 sec inserted : 10000 - updated : 0
Total time for upsert with MERGE: 7.852903842926025 sec inserted : 5000 - updated : 5000
Total time for upsert with MERGE: 15.11970829963684 sec inserted : 0 - updated : 10000

Comments

  • mgoldwasser over 2 years

    I am trying to write a bulk upsert in python using the SQLAlchemy module (not in SQL!).

    I am getting the following error on a SQLAlchemy add:

    sqlalchemy.exc.IntegrityError: (IntegrityError) duplicate key value violates unique constraint "posts_pkey"
    DETAIL:  Key (id)=(TEST1234) already exists.
    

    I have a table called posts with a primary key on the id column.

    In this example, I already have a row in the db with id=TEST1234. When I attempt to db.session.add() a new posts object with the id set to TEST1234, I get the error above. I was under the impression that if the primary key already exists, the record would get updated.

    How can I upsert with Flask-SQLAlchemy based on primary key alone? Is there a simple solution?

    If there is not, I can always check for and delete any record with a matching id, and then insert the new record, but that seems expensive for my situation, where I do not expect many updates.