/*
 * Copyright Debezium Authors.
 *
 * Licensed under the Apache Software License version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0
 */
package io.debezium.pipeline.source.snapshot.incremental;

import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.List;
import java.util.OptionalLong;

import io.debezium.jdbc.JdbcConnection;
import io.debezium.relational.Column;
import io.debezium.relational.RelationalDatabaseConnectorConfig;
import io.debezium.relational.Table;
import io.debezium.spi.schema.DataCollectionId;

/**
 * Builds queries for reading incremental snapshot chunks from a table using row value constructors.
 * <p>
 * On some database engines, these queries result in the most efficient query plans when a suitable index exists, but are only
 * compatible with databases that support ROW() syntax. If any of the query (key) columns is nullable, every hook in this class
 * falls back to the base class implementation, because row value comparisons do not handle NULL values correctly.
 */
public class RowValueConstructorChunkQueryBuilder<T extends DataCollectionId> extends AbstractChunkQueryBuilder<T> {

    // NOTE: MySQL 8.0 is compatible with this syntax, but for now we can't reasonably use it because the query planner doesn't know
    // how to optimally use row value constructors: https://bugs.mysql.com/bug.php?id=111952. In fact, some simple testing shows
    // that MySQL 8.0 (unlike PostgreSQL) _does_ optimally support the query generated by the base class. We can switch MySQL to
    // use RowValueConstructorChunkQueryBuilder once we can show that the query planner yields optimal results with it.

    public RowValueConstructorChunkQueryBuilder(RelationalDatabaseConnectorConfig config,
                                                JdbcConnection jdbcConnection) {
        super(config, jdbcConnection);
    }

    /**
     * Decides whether the base class implementation must be used instead of the ROW() syntax.
     *
     * @param pkColumns the columns used to bound the chunk query
     * @return {@code true} if at least one of the columns is optional (nullable)
     */
    private boolean fallbackToSuper(List<Column> pkColumns) {
        // Row value constructors don't work for columns that may contain NULL values. At least for PostgreSQL,
        // there's a known alternative, but it involves creating a composite type first. For now, the ROW()
        // syntax is only used when all columns are NOT NULL, since non-nullable columns should be the common case. See:
        // https://issues.redhat.com/browse/DBZ-5071?focusedId=22499985&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-22499985
        return pkColumns.stream().anyMatch(Column::isOptional);
    }

    /**
     * Appends a comparison of the form {@code ROW(k1, k2, k3) <operator> ROW(?, ?, ?)} to {@code sql}.
     * The placeholder values are not bound here; {@link #readTableChunkStatement} binds them.
     *
     * @param pkColumns the columns to compare, in comparison order
     * @param operator the SQL comparison operator to place between the two row values
     * @param sql the builder the comparison is appended to
     */
    private void rowValueComparison(List<Column> pkColumns, String operator, StringBuilder sql) {
        appendRowValue(sql, pkColumns, false);
        sql.append(' ').append(operator).append(' ');
        appendRowValue(sql, pkColumns, true);
    }

    /**
     * Appends either the quoted names of {@code columns} or one {@code ?} placeholder per column, comma-separated.
     * <p>
     * While PostgreSQL is tolerant of row value constructors with a single column, MySQL is not. If you're going to use the
     * ROW() syntax on MySQL, you have to provide at least two columns. A single column is therefore emitted without ROW().
     *
     * @param sql the builder to append to
     * @param columns the columns making up the row value
     * @param placeholders {@code true} to emit {@code ?} placeholders, {@code false} to emit quoted column names
     */
    private void appendRowValue(StringBuilder sql, List<Column> columns, boolean placeholders) {
        final boolean wrapInRow = columns.size() > 1;
        if (wrapInRow) {
            sql.append("ROW(");
        }
        for (int i = 0; i < columns.size(); i++) {
            if (i > 0) {
                sql.append(", ");
            }
            sql.append(placeholders ? "?" : jdbcConnection.quoteIdentifier(columns.get(i).name()));
        }
        if (wrapInRow) {
            sql.append(')');
        }
    }

    /**
     * Appends the lower chunk boundary. ROW() syntax compares multiple columns at once, which helps query planners use an index.
     * <p>
     * Example: {@code ROW(k1, k2, k3) > ROW(?, ?, ?)}
     * <p>
     * {@code boundaryKey} is only used by the base class fallback. In the ROW() path only placeholders are emitted.
     */
    @Override
    protected void addLowerBound(IncrementalSnapshotContext<T> context, Table table, Object[] boundaryKey, StringBuilder sql) {
        final List<Column> pkColumns = getQueryColumns(context, table);
        if (fallbackToSuper(pkColumns)) {
            // Fall back to slower base class implementation that is correct for NULL values.
            super.addLowerBound(context, table, boundaryKey, sql);
            return;
        }

        rowValueComparison(pkColumns, ">", sql);
    }

    /**
     * Appends the upper chunk boundary using ROW() syntax.
     * <p>
     * Example: {@code ROW(k1, k2, k3) <= ROW(?, ?, ?)}
     * <p>
     * NOTE: PostgreSQL 13 is known to create a bad query plan if we fall back to "NOT addLowerBound()", i.e.
     * {@code NOT ROW(k1, k2, k3) > ROW(?, ?, ?)}, which is why this is overridden with an explicit {@code <=} comparison.
     */
    @Override
    protected void addUpperBound(IncrementalSnapshotContext<T> context, Table table, Object[] boundaryKey, StringBuilder sql) {
        final List<Column> pkColumns = getQueryColumns(context, table);
        if (fallbackToSuper(pkColumns)) {
            // Fall back to slower base class implementation that is correct for NULL values.
            super.addUpperBound(context, table, boundaryKey, sql);
            return;
        }

        rowValueComparison(pkColumns, "<=", sql);
    }

    /**
     * Prepares the chunk query and, for non-initial chunks, binds the chunk end position followed by the maximum key
     * to the placeholders emitted by the ROW() comparisons.
     * <p>
     * If binding fails, the prepared statement is closed before the exception is rethrown, since the caller never
     * receives it and therefore cannot close it.
     *
     * @throws SQLException if preparing the statement or binding a value fails
     */
    @Override
    public PreparedStatement readTableChunkStatement(IncrementalSnapshotContext<T> context, Table table, String sql) throws SQLException {
        final List<Column> queryColumns = getQueryColumns(context, table);
        if (fallbackToSuper(queryColumns)) {
            // Fall back to slower base class implementation that is correct for NULL values.
            return super.readTableChunkStatement(context, table, sql);
        }

        final PreparedStatement statement = jdbcConnection.readTablePreparedStatement(connectorConfig, sql,
                OptionalLong.empty());
        if (!context.isNonInitialChunk()) {
            // Placeholders are only bound for non-initial chunks.
            return statement;
        }
        try {
            final Object[] maximumKey = context.maximumKey().get();
            final Object[] chunkEndPosition = context.chunkEndPosititon();
            // Assumes the chunk query emits the lower bound (chunk end position) placeholders before the
            // upper bound (maximum key) placeholders - confirm against AbstractChunkQueryBuilder.
            final int nextIndex = bindKeyValues(statement, queryColumns, chunkEndPosition, 1);
            bindKeyValues(statement, queryColumns, maximumKey, nextIndex);
        }
        catch (SQLException | RuntimeException e) {
            // The statement is not handed to the caller on failure, so close it here to avoid leaking it.
            try {
                statement.close();
            }
            catch (SQLException closeException) {
                e.addSuppressed(closeException);
            }
            throw e;
        }
        return statement;
    }

    /**
     * Binds {@code values} to consecutive placeholders starting at {@code firstIndex}. Each value is passed to
     * {@code setQueryColumnValue} together with the column at the same position in {@code columns}.
     *
     * @return the index of the next unbound placeholder
     * @throws SQLException if binding a value fails
     */
    private int bindKeyValues(PreparedStatement statement, List<Column> columns, Object[] values, int firstIndex) throws SQLException {
        int index = firstIndex;
        for (int i = 0; i < values.length; i++) {
            jdbcConnection.setQueryColumnValue(statement, columns.get(i), index++, values[i]);
        }
        return index;
    }
}
