
Spark / Hive / RDBMS Read and Write Operations


[Project Summary]

Data read and write operations come up constantly in our data engineering work, so this post collects the patterns we use in one place.

The project consists of the following files:

config.ini                         configuration file

func_forecast.sh                   job launch script

mysql-connector-java-8.0.11.jar    MySQL JDBC driver

predict_pred.py                    main entry script

utils.py                           utility functions

1 config.ini

Defines the parameter values.

[spark]
executor_memory = 2g
driver_memory = 2g
sql_execution_arrow_enabled = true
executor_cores = 2
executor_instances = 2
jar = /root/spark-rdbms/mysql-connector-java-8.0.11.jar

[mysql]
host = 11.23.32.16
port = 3306
db = test
user = test
password = 123456

[postgresql]
host = 11.23.32.16
port = 5432
db = test
user = test
password = 123456

[dbms_parameters]
mysql_conn_info_passwd @@@ =4bI=prodha
mysql_conn_info_database = test_db
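
For reference, a minimal sketch of how these values are read with ConfigParser (the same approach utils.py uses below). Note that the jar entry in [spark] is not consumed by the Python code shown in this post; the launch scripts pass the JAR to spark-submit via --jars instead.

# minimal sketch: reading config.ini with ConfigParser (assumes config.ini sits in the working directory)
from configparser import ConfigParser

cf = ConfigParser()
cf.read('config.ini', encoding='utf-8')

executor_memory = cf.get('spark', 'executor_memory')                    # '2g'
mysql_port = cf.getint('mysql', 'port')                                 # 3306
arrow_enabled = cf.getboolean('spark', 'sql_execution_arrow_enabled')   # True

print(executor_memory, mysql_port, arrow_enabled)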

2 utils.py

Defines the helper functions.

from configparser import ConfigParser
from pyspark import SparkConf
import psycopg2
import pymysql

# pip install pymysql
# pip install psycopg2-binary


def get_config():
    """
    获取整个配置文件信息
    
    Returns
    -------
        cf: ConfigParser
            配置文件信息
    """
    cf = ConfigParser()
    cf.read('config.ini', encoding='utf-8')
    return cf

def get_user_pwd():
    """
    获取用户、密码
    
    Returns
    -------
        strUser: Str
            连接用户
        strPassword: Str
            连接密码
    """
    cf = get_config()
    
    strUser = cf.get('mysql', 'user')
    strPassword = cf.get('mysql', 'password')
    
    return strUser, strPassword

def get_conn_url():
    """
    通过jdbc方式获取连接URL
    
    Returns
    -------
        strUrl: Str
            连接URL
    """
    cf = get_config()
    
    host = cf.get('mysql', 'host')
    port = cf.get('mysql', 'port')
    db = cf.get('mysql', 'db')
    strUrl = f'mysql://{host}:{port}/{db}?useSSL=false'
    return strUrl

def get_conn_properties():
    """
    获取连接用户名+密码字典
    
    Returns
    -------
        strProperties: Dict
            用户、密码组成的字典
    """
    cf = get_config()
    
    user = cf.get('mysql', 'user')
    password = cf.get('mysql', 'password')
    strProperties = {'user': user, 'password': password}
    
    return strProperties

def get_spark_conf():
    """
    获取SparkConf
    
    Returns
    -------
        sparkConf: pyspark.sql.SparkConf
            spark配置参数
    """
    cf = get_config()
    sparkConf = SparkConf()
    sparkConf.set("spark.executor.memoryOverhead", "1g")
    sparkConf.set("spark.driver.maxResultSize", "4g")
    sparkConf.set('spark.some.config.option', 'some-value')
    sparkConf.set('spark.executor.memory', cf.get('spark', 'executor_memory'))
    sparkConf.set('spark.driver.memory', cf.get('spark', 'driver_memory'))
    sparkConf.set('spark.executor.instances', cf.get('spark', 'executor_instances'))
    sparkConf.set('spark.executor.cores', cf.get('spark', 'executor_cores'))
    sparkConf.set('spark.sql.execution.arrow.enabled', cf.get('spark', 'sql_execution_arrow_enabled'))
    
    return sparkConf

def read_dataset(spark, strURL, strDBTable, strUser, strPassword, isConcurrent,
                 partitionColumn='', lowerBound='', upperBound='', numPartitions=1):
    """
    Spark读取数据 jdbc
    
    Parameters
    ----------
        spark: pyspark.sql.SparkSession
            SparkSession
        strURL: Str
            连接URL
        strDBTable: Str
            表名
        strUser: Str
            用户
        strPassword: Str
            密码
        isConcurrent: Bool
            是否并发读取(默认并发为1)
        partitionColumn: Str
            Must be a numeric, date, or timestamp column from the table.
        lowerBound: Str
            Used to decide the partition stride.
        upperBound: Str
            Used to decide the partition stride.
        numPartitions: Str
            The maximum number of partitions that can be used for parallelism
    Returns
    -------
        pyspark.sql.DataFrame
    """
    if isConcurrent:
        jdbcDf=spark.read.format("jdbc") \
                         .option('url', f"jdbc:{strURL}") \
                         .option('dbtable', strDBTable) \
                         .option('user', strUser) \
                         .option('password', strPassword) \
                         .option('partitionColumn', partitionColumn) \
                         .option('lowerBound', lowerBound) \
                         .option('upperBound', upperBound) \
                         .option('numPartitions', numPartitions) \
                         .load()
        return jdbcDf
    else:
        jdbcDf=spark.read.format("jdbc") \
                         .option('url', f"jdbc:{strURL}") \
                         .option('dbtable', strDBTable) \
                         .option('user', strUser) \
                         .option('password', strPassword) \
                         .load()
        return jdbcDf

def write_dataset(srcDF, strURL, desDBTable, dictProperties, mode="append"):
    """
    Spark写入数据 jdbc
    
    Parameters
    ----------
        srcDF: pyspark.sql.DataFrame
            待写入源DataFrame
        strURL: Str
            连接URL
        desDBTable: Str
            写入数据库的表名
        dictProperties: Dict
            用户密码组成的字典
        mode: Str
            插入模式("append": 增量更新; "overwrite": 全量更新)

    """
    try:
        srcDF.write.jdbc(
            f'jdbc:{strURL}', 
            table=desDBTable, 
            mode=mode, 
            properties=dictProperties
        )
    except BaseException as e:
        print(repr(e))

def psycopg_execute(strSql):
    """
    使用psycopg2方式执行sql
    
    Parameters
    ----------
        strSql: Str
            待执行的sql
    """
    cf = get_config()
    
    # 连接数据库
    conn = psycopg2.connect(
        dbname=cf.get('postgresql', 'db'),
        user=cf.get('postgresql', 'user'),
        password=cf.get('postgresql', 'password'), 
        host=cf.get('postgresql', 'host'), 
        port=cf.get('postgresql', 'port')
    )

    # 创建cursor以访问数据库
    cur = conn.cursor()

    try:
        # 执行操作
        cur.execute(strSql)
        # 提交事务
        conn.commit()
    except psycopg2.Error as e:
        print('DML操作异常')
        print(e)
        # 有异常回滚事务
        conn.rollback()
    finally:
        # 关闭连接
        cur.close()
        conn.close()

def get_sql_conn():
    """
    连接sql数据库的engine:可用于使用数据库连接
    
    Parameters
    ----------
        conn: Str
            数据库的engine
    """
    cf = get_config()
    conn = pymysql.connect(
        host=cf.get('mysql', 'host'), 
        port=int(cf.get('mysql', 'port')),
        user=cf.get('mysql', 'user'),
        password=cf.get('mysql', 'password'),
        database=cf.get('mysql', 'db'),
        charset='utf8'   # charset="utf8",编码不要写成"utf-8"
    )
    return conn

def mysql_query_data(strSql):
    """
    使用mysql查询
    
    Parameters
    ----------
        strSql: Str
            待执行的sql

    Returns
    -------
        list[dict]
    """    
    # 连接数据库
    conn = get_sql_conn()

    # 创建cursor以访问数据库
    cur = conn.cursor(pymysql.cursors.DictCursor)

    try:
        # 执行操作
        cur.execute(strSql)
        return cur.fetchall()
    finally:
        # 关闭连接
        cur.close()
        conn.close()

def mysql_insert_or_update_data(strSql):
    """
    使用mysql执行insert或update操作
    
    Parameters
    ----------
        strSql: Str
            待执行的sql
    """    
    # 连接数据库
    conn = get_sql_conn()

    # 创建cursor以访问数据库
    cur = conn.cursor()

    try:
        # 执行操作
        cur.execute(strSql)
        # 提交事务
        conn.commit()
    except Exception:
        print('DML操作异常')
        # 有异常回滚事务
        conn.rollback()
    finally:
        # 关闭连接
        cur.close()
        conn.close()
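
psycopg_execute is not exercised by the test script in section 4; a minimal usage sketch (the table name here is hypothetical) would be:

# minimal usage sketch for psycopg_execute; the table name test.demo_table is hypothetical
from utils import psycopg_execute

psycopg_execute("truncate table test.demo_table;")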

3 PostgreSQL read / write / insert / update

import os
import io
import re
import psycopg2
import numpy as np
import pandas as pd
from datetime import datetime
from sqlalchemy import create_engine
from configparser import ConfigParser
import subprocess
# The subprocess module lets you spawn new processes, connect to their input/output/error pipes,
# and obtain their return codes. run() is the recommended entry point; Popen is the lower-level interface.

def is_windows():
    import sys
    return sys.platform.startswith('win')

# decrypt an encrypted password via the external 'password' command
def decrypt_password(encrypted_password: str):
    password_command = 'password'
    cmd = subprocess.Popen(f'{password_command} decrypt {encrypted_password}', stdin=subprocess.PIPE,
                           stderr=subprocess.PIPE,
                           stdout=subprocess.PIPE, universal_newlines=True, shell=True, bufsize=1)
    return '\n'.join(cmd.stdout.readlines())

# encrypt a plaintext password via the external 'password' command
def encrypt_password(decrypted_password: str):
    password_command = 'password'
    cmd = subprocess.Popen(f'{password_command} encrypt {decrypted_password}', stdin=subprocess.PIPE,
                           stderr=subprocess.PIPE,
                           stdout=subprocess.PIPE, universal_newlines=True, shell=True, bufsize=1)
    stderr_message = '\n'.join(cmd.stderr.readlines())
    if stderr_message.strip() != '':
        raise Exception(stderr_message)
    return '\n'.join(cmd.stdout.readlines())


def get_config_file_path():
    path = os.path.abspath(os.path.join(os.getcwd(), ".."))
    return f'{path}/algo_config/config.ini'


def get_relative_config_file_path():
    return "../algo_config/config.ini"


def get_config():
    # resolve the config file path
    path = get_config_file_path()

    # initialise the config parser
    cf = ConfigParser()
    cf.read(path, encoding='utf-8')
    return cf


def get_env(env: str = None):
    # load the config parser
    cf = get_config()

    # initialise the environment name, defaulting to 'default'
    env = 'default' if env is None else env

    # if an environment file exists on the host, it overrides the default
    if os.path.exists('/data/env_conf/env_name.conf'):
        with open('/data/env_conf/env_name.conf') as env_f:
            env = env_f.read().strip()

    return cf, env


def get_user_pwd(env: str = None):
    # load config parser and resolve the environment
    cf, env = get_env(env)

    # build the section name for the current environment
    partition = f'postgresql_{env}'

    # read the user and decrypt the stored password
    user = cf.get(partition, 'user')
    password = decrypt_password(cf.get(partition, 'password'))
    # print(f'current env: {env}, read partition: {partition}, user: {user}, password: {password}')

    return user, password


def get_conn_info(env=None):
    # load config parser and resolve the environment
    cf, env = get_env(env)

    # build the section name for the current environment
    partition = f'postgresql_{env}'

    # read the database host, port and database name
    host = cf.get(partition, 'host')
    port = cf.get(partition, 'port')
    db = cf.get(partition, 'db')

    # print(f'current env: {env}, read partition: {partition}, host: {host}, port: {port}, db: {db}')
    return host, port, db


def replace_value(x):
    if x == "''::character varying":
        return ' '
    elif x == 'now()':
        return datetime.now()
    elif x is np.nan:
        return x
    elif x is None:
        return x
    elif x.startswith('('):
        return re.findall(r'\((.*?)\)', x)[0]
    else:
        return x


def read_dataset(sql):
    # read connection info from the config file
    user, password = get_user_pwd()
    host, port, db = get_conn_info()

    # connect and read the query result into a DataFrame
    conn = psycopg2.connect(database=db, user=user, password=password, host=host, port=port)
    df = pd.read_sql(sql, con=conn)
    return df


def write_dataset(df, schema, table_name, if_exists='fail'):
    """
        if_exists mode:
            fail: 数据库不存在目标表时, 根据待写入数据创建对应的数据表, 数据可正常写入, 表存在则写入失败
            append:  数据库中存在目标表, 将数据追加到目标表中
            replace: 数据库中存在目标表, 将目标表中的数据替换为当前数据

        利用copy的方式将数据写入目标表
    """

    # 配置文件信息读取
    user, password = get_user_pwd()
    host, port, db = get_conn_info()

    # 连接数据库
    db_engine = create_engine(f'postgresql://{user}:{password}@{host}:{port}/{db}')
    string_data_io = io.StringIO()
    df.to_csv(string_data_io, sep=',', index=False, header=True)
    pd_sql_engine = pd.io.sql.pandasSQL_builder(db_engine)
    table = pd.io.sql.SQLTable(table_name, pd_sql_engine, frame=df, index=False, if_exists=if_exists, schema=schema)
    table.create()
    string_data_io.seek(0)
    with db_engine.connect() as conn:
        with conn.connection.cursor() as cursor:
            copy_cmd = f"COPY {schema}.{table_name} FROM STDIN WITH HEADER DELIMITER ',' CSV"
            cursor.copy_expert(copy_cmd, string_data_io)
        conn.connection.commit()


def insert_df_to_table(df, schema, table_name, write_flag="append"):
    # fetch all columns of the target table together with their default values
    sql = f"""select column_name, column_default from information_schema.columns where table_schema = '{schema}' and 
    table_name = '{table_name}' order by ordinal_position """
    field_df = read_dataset(sql)
    field_df['column_default'] = field_df['column_default'].apply(lambda x: replace_value(x))
    field_dict = dict(zip(field_df['column_name'], field_df['column_default']))
    field_list = field_df['column_name'].tolist()

    # columns present in the DataFrame
    columns_list = list(df.columns)

    # columns the DataFrame is missing
    other_list = list(set(field_list) - set(columns_list))

    # fill the missing columns with their defaults
    for x in other_list:
        value = field_dict[x]
        df[x] = value

    # reorder the columns to match the table
    df = df[field_list]

    # write the data
    write_dataset(df, schema, table_name, write_flag)
    
    
if __name__ == '__main__':
    # quick sanity check: print the decrypted password for the dev environment
    user, passwd = get_user_pwd(env='dev')
    print(passwd)


def init_table():
    # read connection info from the config file
    user, password = get_user_pwd()
    host, port, db = get_conn_info()

    sql = """truncate arp_dw1.tb3_bom_safety_df;"""

    connect = psycopg2.connect(dbname=db, user=user, password=password, host=host, port=port)
    cur = connect.cursor()
    try:
        cur.execute(sql)
        connect.commit()
    except BaseException as e:
        print(repr(e))
        connect.rollback()
    connect.close()

def sku_meta_data_insert(result_list):
    # insert SKU grouping metadata
    # read connection info and open the database connection
    user, password = get_user_pwd()
    host, port, db = get_conn_info()
    connect = psycopg2.connect(database=db, user=user, password=password, host=host, port=port)
    cur = connect.cursor()

    sql = """insert into arp_app.alg_sku_grouping(sku_grouping_name, sku_grouping_id, status, create_by) values(%s, 
        %s, %s, %s) """
    try:
        cur.executemany(sql, result_list)
        connect.commit()
    except BaseException as e:
        print(repr(e))
        connect.rollback()
    connect.close()


def update_table(df, schema, table):
    # stage the increment into a temporary table first
    write_dataset(df, 'arp_dw1', 'temp_sku_group_inc', 'fail')

    # read connection info and open the database connection
    user, password = get_user_pwd()
    host, port, db = get_conn_info()
    connect = psycopg2.connect(database=db, user=user, password=password, host=host, port=port)
    cur = connect.cursor()

    sql = f"""insert into {schema}.{table}(unit, primary_category, secondary_category, material_code, 
    sku_grouping_id, status, create_by) select * from arp_dw1.temp_sku_group_inc"""

    try:
        cur.execute(sql)
        connect.commit()
    except BaseException as e:
        print(repr(e))
        connect.rollback()

    drop_sql = """drop table arp_dw1.temp_sku_group_inc"""

    try:
        cur.execute(drop_sql)
        connect.commit()
    except BaseException as e:
        print(repr(e))
        connect.rollback()
    connect.close()
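
A minimal usage sketch for insert_df_to_table (the schema, table and columns below are hypothetical): any columns the DataFrame lacks are filled from their column defaults before the COPY-based write.

# minimal usage sketch; the schema, table and columns are hypothetical
import pandas as pd

new_rows = pd.DataFrame({
    'material_code': ['A001', 'A002'],
    'sku_grouping_id': [1, 1],
})

# missing table columns are filled from their defaults, then written via the COPY-based write_dataset above
insert_df_to_table(new_rows, schema='arp_dw1', table_name='tb_demo', write_flag='append')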

4 Test script predict_pred.py

from pyspark.sql import SparkSession
from pyspark.sql import functions as fn
from pyspark.ml import Pipeline, util
from pyspark.ml import feature as ft

import pandas as pd
import time

import sys
print(sys.path)
sys.path.append('/root/spark-rdbms')  # add the module directory to sys.path
print(sys.path)

from utils import get_config
from utils import get_user_pwd
from utils import get_conn_url
from utils import get_conn_properties
from utils import get_spark_conf
from utils import read_dataset
from utils import write_dataset
from utils import get_sql_conn
from utils import mysql_query_data
from utils import mysql_insert_or_update_data


if __name__ == '__main__':
    # (1) Spark reads from Hive
    # start the SparkSession
    spark = SparkSession.\
        builder.\
        appName("spark_io").\
        config(conf=get_spark_conf()).\
        enableHiveSupport().\
        getOrCreate()
    
    select_dt = "2020-09-09"
    strSql = f"""
        select * from yth_src.kclb where dt = '{select_dt}'
    """
    df = spark.sql(strSql)
    print(df.show())


    # (2) Spark writes to Hive
    # dt = time.strftime("%Y-%m-%d", time.localtime()) 
    # print(dt)
    # # enable dynamic partitioning
    # spark.sql("set hive.exec.dynamic.partition.mode=nonstrict")
    # spark.sql("set hive.exec.dynamic.partition=true")
    # spark.sql(f"""
    # insert overwrite table yth_dw.shop_inv_turnover_rate partition (dt)
    # select 
    #     shop_code, 
    #     pro_code,

    #     current_timestamp as created_time, 
    #     '{dt}' as dt
    # from shop_inv_turnover_rate_table_db
    # """)

    # (3) Spark reads from MySQL
    # get the user and password
    strUser, strPwd = get_user_pwd()
    jdbcDf = read_dataset(
        spark = spark, 
        strURL = get_conn_url(), 
        strDBTable = 'hive_kclb', 
        strUser = strUser, 
        strPassword = strPwd,
        isConcurrent = False
    )
    print(jdbcDf.show())

    # (4) Spark writes to MySQL
    # write_dataset(
    #     srcDF = jdbcDf,
    #     strURL = get_conn_url(),
    #     desDBTable = 'hive_kclb2',
    #     dictProperties = get_conn_properties(),
    #     mode = 'append'
    # )

    # (5) pymysql reads from MySQL
    strSql2 = """
        select * from hive_kclb limit 5
    """
    df2 = mysql_query_data(strSql2)   # returns list[dict]
    df2 = pd.DataFrame(df2)
    print(df2)

    # (6) the same query through the pymysql connection and pandas.read_sql
    conn = get_sql_conn()
    df3 = pd.read_sql(strSql2, con=conn)
    print(df3)
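
mysql_insert_or_update_data is imported above but never called; a minimal sketch (the status column and predicate are hypothetical) would be:

# minimal sketch for mysql_insert_or_update_data; the status column and predicate are hypothetical
from utils import mysql_insert_or_update_data

strSql3 = """
    update hive_kclb set status = 1 where dt = '2020-09-09'
"""
mysql_insert_or_update_data(strSql3)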

As a supplement, Spark's JDBC read can also take a subquery in place of a table name:

from pyspark.sql import SparkSession
import sys

from utils import get_spark_conf
from utils import get_user_pwd
from utils import get_conn_url
from utils import read_dataset

if __name__ == "__main__":
    select_dt = sys.argv[1]
    print("分区:",select_dt)

    # spark初始化
    spark = SparkSession.\
                builder.\
                appName("spark_job").\
                config(conf=get_spark_conf()).\
                enableHiveSupport().\
                getOrCreate()

    strUser, strPwd = get_user_pwd()
    mysql_query_1 = f"""
        (
	select 
            t1.purchase_date, 
            t1.shop_code, 
            t3.shop_name, 
            t2.create_datetime, 
            t2.status, 
            if(t2.status=1,'已发布','未发布') as issue_status
        from (
            select distinct purchase_date,shop_code 
            from prod_yth_hn.shop_replenishment 
        ) t1 
        left join (
            select * 
            from prod_yth_hn.shop_replenishment_info 
        ) t2 on t2.shop_code=t1.shop_code and t2.purchase_date=t1.purchase_date
        left join prod_yth_hn.hive_shop_attribute t3 
        on t3.shop_code=t1.shop_code
	) as res
        """

    mysql_df_1 = read_dataset(
        spark = spark,
        strURL = get_conn_url(),
        strDBTable = mysql_query_1,
        strUser = strUser,
        strPassword = strPwd,
        isConcurrent = False
    )

    print(type(mysql_df_1))
    print(mysql_df_1.show())
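
The isConcurrent branch of read_dataset is never exercised in these examples; a sketch of a partitioned (parallel) read, reusing spark, strUser and strPwd from the script above and assuming the table has a numeric id column, might look like this:

# sketch of a partitioned JDBC read; the id column and its bounds are assumptions about the table
mysql_df_parallel = read_dataset(
    spark = spark,
    strURL = get_conn_url(),
    strDBTable = 'hive_kclb',
    strUser = strUser,
    strPassword = strPwd,
    isConcurrent = True,
    partitionColumn = 'id',    # must be a numeric, date, or timestamp column
    lowerBound = '1',
    upperBound = '100000',
    numPartitions = 4
)
print(mysql_df_parallel.rdd.getNumPartitions())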

Submitting with spark-submit lets us specify the Python interpreter and attach the MySQL JDBC JAR:

#!/usr/bin/bash
base_dir="/root/spark-jdbc-test"

# spark-submit binary
spark_interpreter="/opt/spark-2.4.4/bin/spark-submit"

# Python interpreter
python_interpreter="/app/anaconda3/bin/python"

if [[ $# -eq 1 ]]; then
    select_dt=$1
else
    select_dt=$(date -d "1 day ago" +"%Y-%m-%d")
fi
echo "select dt is :${select_dt}"

# run the job
${spark_interpreter} --master yarn \
                    --num-executors 4 \
                    --conf spark.pyspark.python=${python_interpreter} \
                    --executor-memory 4g \
                    --driver-class-path mysql-connector-java-8.0.21.jar \
                    --jars mysql-connector-java-8.0.21.jar ${base_dir}/hn_arp_desc_for_store_wh.py "${select_dt}"

5 Launch script func_forecast.sh

#!/bin/bash

# spark-submit binary
spark_interpreter="/opt/spark-2.4.4/bin/spark-submit"

# Python interpreter
python_interpreter="/app/anaconda3/bin/python"

function func_forecast()
{
    cd /root/spark-rdbms

    # (1)
    # --packages  Maven coordinates of the required JARs
    # --packages  mysql:mysql-connector-java:5.1.27 --repositories http://maven.aliyun.com/nexus/content/groups/public/
    #                                               --repositories gives the Maven repository for mysql-connector-java;
    #                                               if omitted, the machine's default Maven repositories are used
    # For multiple dependencies, repeat the coordinate pattern, separated by commas.
    # Downloaded packages land in the .ivy/jars folder under the current user's home directory.
    # Use case: the JAR need not exist locally; the cluster downloads it from the given Maven repository.

    # (2)
    # --jars JARS
    # Use case: the JAR file must exist locally.
    # --driver-class-path extra classpath entries for the driver; paths given via --jars are added
    #                     automatically; separate multiple entries with a colon (:)

    # run the job
    ${spark_interpreter} --master yarn \
                        --conf spark.pyspark.python=${python_interpreter} \
                        --driver-class-path mysql-connector-java-8.0.11.jar \
                        --jars mysql-connector-java-8.0.11.jar predict_pred.py

    if [[ $? -ne 0 ]]; then
        echo "--> execution failed"
        exit 1
    fi
}
func_forecast


6 Reading the config file from shell

Write a shared reader script, ini_pro.sh:

#!/usr/bin/bash
###
 # @Author: ydzhao
 # @Description: read configuration file's parameters
 # @Date: 2020-07-27 13:10:31
 # @LastEditTime: 2020-09-13 07:33:14
 # @FilePath: /git/code/ydzhao/spark-hive-rdbms/conf/ini.sh
### 

INIFILE=$1
SECTION=$2
ITEM=$3
SPLIT=$4
NEWVAL=$5


function ReadINIfile(){ 
  ReadINI=`awk -F $SPLIT '/\['$SECTION'\]/{a=1}a==1&&$1~/'$ITEM'/{print $2;exit}' $INIFILE`
  echo $ReadINI
}


function WriteINIfile(){
   WriteINI=`sed -i "/^\[$SECTION\]/,/^\[/ {/^\[$SECTION\]/b;/^\[/b;s/^$ITEM*=.*/$ITEM=$NEWVAL/g;}" $INIFILE`
  echo $WriteINI
}

if [ "$5" = "" ] ;then 
   ReadINIfile $1 $2 $3 $4
else
   WriteINIfile $1 $2 $3 $4 $5
fi

# ./ini.sh $1 $2 $3 $4       read a value from the ini file
# ./ini.sh $1 $2 $3 $4 $5    write "newval" to the ini file

To read the config.ini file shown above, we can call the script like this:

[root@ythbdnode01 ~]# mysql_conn_info_database=`bash /root/ini_pro.sh /root/config.ini dbms_parameters mysql_conn_info_database =`
[root@ythbdnode01 ~]# echo $mysql_conn_info_database
test_db

Four arguments are passed in:

  • INIFILE  the configuration file
  • SECTION  the section to read
  • ITEM     the key within that section
  • SPLIT    the key/value delimiter

Because the password value itself contains '=' characters, it is stored in [dbms_parameters] with @@@ as the delimiter, and the same delimiter is passed to the script:
[root@ythbdnode01 ~]# mysql_conn_info_passwd=`bash /root/ini_pro.sh /root/config.ini dbms_parameters mysql_conn_info_passwd @@@`
[root@ythbdnode01 ~]# echo $mysql_conn_info_passwd
=4bI=prodha