R, Python, DB 備忘録

データベースとか、jupyter(Python)、Rとか色々

pandas only support SQLAlchemy connectable

現象

いつのバージョンアップからかはよくわからないが、最近pandasの最新版をインストールしたところ、接続情報conpyodbcのconnectionを使うとタイトルのUserWarningが表示されるようになった。

再現イメージ

import pandas as pd
import pyodbc
cnxn = pyodbc.connect(*******)

pd.read_sql_query(******, cnxn)

原因

原因は警告メッセージに表示される通りで、接続情報conとして認められるのは以下の3種類のみ
1. SQLAlchemyによる接続
1. 接続文字列
1. sqlite3(DBAPI2)

対処(不完全)

SQLAlchemyを使え、というのだから使えばよい。

from sqlalchemy import create_engine
cnxn = create_engine('awsathena+rest://{aws_access_key_id}:{aws_secret_access_key}@athena.{region_name}.amazonaws.com/{schema_name}?s3_staging_dir={s3_staging_dir}&...')
# 以下同様

問題点

上記の対処法には問題があり、大きな結果データフレームを取得しようとすると時間がかかりすぎてしまうことである。
以下の記事が参考になるが、結果10万行あたりから対処2(後述)と比べて大きな差(2.5秒 vs 21秒)が出てきている。
medium.com

対処2

Boto3を使い、Select結果をS3に保存しダウンロードする、という方法が実用的
具体的な方法は上記の記事の「Method 2: Use Boto3 and download results file」にも詳しいが、ここではAthenaを使う場合のAWS SSOを前提としたコードを備忘として残しておく。

# 以下をまとめて行うpandas.read_sql_query ライクなUDF read_sql_query の作成例
# boto3.client(athena)でs3上に結果を生成
# boto3.client(s3)でダウンロード
# pandas.read_csv
import subprocess
import boto3

subprocess.call("aws sso login --profile your_profile_name")   # AWS SSO Login(ブラウザ立ち上がる)
boto3.setup_default_session(profile_name = 'your_profile_name')

def read_sql_query(queryStr): 
    # クライアントセッションの開始
    client_athena = boto3.client('athena')
    client_s3 = boto3.client('s3')
    
    # AWS関係の変数の定義
    SCHEMA_NAME = "default"
    S3_BUCKET_NAME = "query-results"
    S3_OUTPUT_DIRECTORY = "shipapa15"
    S3_STAGING_DIR = f"s3://{S3_BUCKET_NAME}/{S3_OUTPUT_DIRECTORY}"
    temp_file_location: str = "C:/Temp/athena_query_results.csv"
    
    query_response = client_athena.start_query_execution(
        QueryString = queryStr,
        QueryExecutionContext = {"Database": SCHEMA_NAME},
        ResultConfiguration = {
            "OutputLocation": S3_STAGING_DIR,
            "EncryptionConfiguration": {"EncryptionOption": "SSE_S3"},
        },
    )
    # クエリの完了まで待機
    while True:
        try:
            # This function only loads the first 1000 rows
            client_athena.get_query_results(
                QueryExecutionId=query_response["QueryExecutionId"]
            )
            break
        except Exception as err:
            if "not yet finished" in str(err):
                time.sleep(1)     # import time 要
            else:
                raise err
    # 結果CSVの取得
    client_s3.download_file(
        S3_BUCKET_NAME,
        f"{S3_OUTPUT_DIRECTORY}/{query_response['QueryExecutionId']}.csv",
        temp_file_location,
    )
    return pd.read_csv(temp_file_location)