Spark_Hive_RDBMS Read/Write Operations
[Project Summary]
In our data-engineering work we rely heavily on reading data from and writing data to Hive and relational databases; the patterns we use are collected below.
The project consists of the following files:
config.ini  configuration file
func_forecast.sh  job launch script
mysql-connector-java-8.0.11.jar  MySQL JDBC driver
predict_pred.py  main entry script
utils.py  utility functions
1 config.ini
Defines the parameter values.
[spark]
executor_memory = 2g
driver_memory = 2g
sql_execution_arrow_enabled = true
executor_cores = 2
executor_instances = 2
jar = /root/spark-rdbms/mysql-connector-java-8.0.11.jar
[mysql]
host = 11.23.32.16
port = 3306
db = test
user = test
password = 123456
[postgresql]
host = 11.23.32.16
port = 5432
db = test
user = test
password = 123456
[dbms_parameters]
# '@@@' is used as the separator on the next line because the password itself contains '='; see section 6
mysql_conn_info_passwd @@@ =4bI=prodha
mysql_conn_info_database = test_db
2 utils.py
Defines the helper functions.
from configparser import ConfigParser
from pyspark import SparkConf
import psycopg2
import pymysql
# pip install pymysql
# pip install psycopg2-binary
def get_config():
"""
获取整个配置文件信息
Returns
-------
cf: ConfigParser
配置文件信息
"""
cf = ConfigParser()
cf.read('config.ini', encoding='utf-8')
return cf
def get_user_pwd():
"""
获取用户、密码
Returns
-------
strUser: Str
连接用户
strPassword: Str
连接密码
"""
cf = get_config()
strUser = cf.get('mysql', 'user')
strPassword = cf.get('mysql', 'password')
return strUser, strPassword
def get_conn_url():
"""
通过jdbc方式获取连接URL
Returns
-------
strUrl: Str
连接URL
"""
cf = get_config()
host = cf.get('mysql', 'host')
port = cf.get('mysql', 'port')
db = cf.get('mysql', 'db')
strUrl = f'mysql://{host}:{port}/{db}?useSSL=false'
return strUrl
def get_conn_properties():
"""
获取连接用户名+密码字典
Returns
-------
strProperties: Dict
用户、密码组成的字典
"""
cf = get_config()
user = cf.get('mysql', 'user')
password = cf.get('mysql', 'password')
strProperties = {'user': user, 'password': password}
return strProperties
def get_spark_conf():
"""
获取SparkConf
Returns
-------
sparkConf: pyspark.sql.SparkConf
spark配置参数
"""
cf = get_config()
sparkConf = SparkConf()
sparkConf.set("spark.executor.memoryOverhead", "1g")
sparkConf.set("spark.driver.maxResultSize", "4g")
sparkConf.set('spark.some.config.option', 'some-value')
sparkConf.set('spark.executor.memory', cf.get('spark', 'executor_memory'))
sparkConf.set('spark.driver.memory', cf.get('spark', 'driver_memory'))
sparkConf.set('spark.executor.instances', cf.get('spark', 'executor_instances'))
sparkConf.set('spark.executor.cores', cf.get('spark', 'executor_cores'))
sparkConf.set('spark.sql.execution.arrow.enabled', cf.get('spark', 'sql_execution_arrow_enabled'))
return sparkConf
def read_dataset(spark, strURL, strDBTable, strUser, strPassword, isConcurrent,
partitionColumn='', lowerBound='', upperBound='', numPartitions=1):
"""
Spark读取数据 jdbc
Parameters
----------
spark: pyspark.sql.SparkSession
SparkSession
strURL: Str
连接URL
strDBTable: Str
表名
strUser: Str
用户
strPassword: Str
密码
isConcurrent: Bool
是否并发读取(默认并发为1)
partitionColumn: Str
Must be a numeric, date, or timestamp column from the table.
lowerBound: Str
Used to decide the partition stride.
upperBound: Str
Used to decide the partition stride.
numPartitions: Str
The maximum number of partitions that can be used for parallelism
Returns
-------
pyspark.sql.DataFrame
"""
    reader = spark.read.format("jdbc") \
        .option('url', f"jdbc:{strURL}") \
        .option('dbtable', strDBTable) \
        .option('user', strUser) \
        .option('password', strPassword)
    if isConcurrent:
        # partitioned (parallel) read
        reader = reader \
            .option('partitionColumn', partitionColumn) \
            .option('lowerBound', lowerBound) \
            .option('upperBound', upperBound) \
            .option('numPartitions', numPartitions)
    return reader.load()
def write_dataset(srcDF, strURL, desDBTable, dictProperties, mode="append"):
"""
Spark写入数据 jdbc
Parameters
----------
srcDF: pyspark.sql.DataFrame
待写入源DataFrame
strURL: Str
连接URL
desDBTable: Str
写入数据库的表名
dictProperties: Dict
用户密码组成的字典
mode: Str
插入模式("append": 增量更新; "overwrite": 全量更新)
"""
try:
srcDF.write.jdbc(
f'jdbc:{strURL}',
table=desDBTable,
mode=mode,
properties=dictProperties
)
except BaseException as e:
print(repr(e))
def psycopg_execute(strSql):
    """
    Execute a SQL statement through psycopg2.

    Parameters
    ----------
    strSql: Str
        SQL statement to execute
    """
    cf = get_config()
    # connect to the database
    conn = psycopg2.connect(
        dbname=cf.get('postgresql', 'db'),
        user=cf.get('postgresql', 'user'),
        password=cf.get('postgresql', 'password'),
        host=cf.get('postgresql', 'host'),
        port=cf.get('postgresql', 'port')
    )
    # create a cursor
    cur = conn.cursor()
    try:
        # run the statement
        cur.execute(strSql)
        # commit the transaction
        conn.commit()
    except psycopg2.Error as e:
        print('DML execution failed')
        print(e)
        # roll back on error
        conn.rollback()
    finally:
        # close the connection
        cur.close()
        conn.close()
def get_sql_conn():
    """
    Open a pymysql connection to the MySQL database.

    Returns
    -------
    conn: pymysql.connections.Connection
        database connection
    """
    cf = get_config()
    conn = pymysql.connect(
        host=cf.get('mysql', 'host'),
        port=int(cf.get('mysql', 'port')),
        user=cf.get('mysql', 'user'),
        password=cf.get('mysql', 'password'),
        database=cf.get('mysql', 'db'),
        charset='utf8'  # use 'utf8', not 'utf-8'
    )
    return conn
def mysql_query_data(strSql):
    """
    Run a query against MySQL.

    Parameters
    ----------
    strSql: Str
        SQL query to execute

    Returns
    -------
    list[dict]
    """
    # connect to the database
    conn = get_sql_conn()
    # use a DictCursor so each row comes back as a dict
    cur = conn.cursor(pymysql.cursors.DictCursor)
    try:
        # run the query
        cur.execute(strSql)
        return cur.fetchall()
    finally:
        # close the connection
        cur.close()
        conn.close()
def mysql_insert_or_update_data(strSql):
    """
    Run an insert or update statement against MySQL.

    Parameters
    ----------
    strSql: Str
        SQL statement to execute
    """
    # connect to the database
    conn = get_sql_conn()
    # create a cursor
    cur = conn.cursor()
    try:
        # run the statement
        cur.execute(strSql)
        # commit the transaction
        conn.commit()
    except Exception:
        print('DML execution failed')
        # roll back on error
        conn.rollback()
    finally:
        # close the connection
        cur.close()
        conn.close()
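For reference, a minimal usage sketch of the concurrent JDBC read and the JDBC write defined above. The table names, the partition column `id`, and the bounds are hypothetical; in practice lowerBound/upperBound would come from something like `select min(id), max(id)`.
from pyspark.sql import SparkSession
from utils import get_spark_conf, get_user_pwd, get_conn_url, get_conn_properties, read_dataset, write_dataset

# the MySQL JDBC driver must be on the classpath (see section 5)
spark = SparkSession.builder.appName("jdbc_demo").config(conf=get_spark_conf()).getOrCreate()
strUser, strPwd = get_user_pwd()

# partitioned read: 4 tasks, each scanning a slice of the (hypothetical) numeric column `id`
df = read_dataset(
    spark=spark,
    strURL=get_conn_url(),
    strDBTable='big_table',          # hypothetical table name
    strUser=strUser,
    strPassword=strPwd,
    isConcurrent=True,
    partitionColumn='id',            # must be a numeric, date or timestamp column
    lowerBound='1',
    upperBound='1000000',
    numPartitions=4
)

# append the result to another table over the same JDBC URL
write_dataset(df, get_conn_url(), 'big_table_copy', get_conn_properties(), mode='append')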
3 PostgreSQL read/write/insert/update
import os
import io
import re
import psycopg2
import numpy as np
import pandas as pd
from datetime import datetime
from sqlalchemy import create_engine
from configparser import ConfigParser
import subprocess
# The subprocess module lets you spawn new processes, connect to their input/output/error pipes, and obtain their return codes.
# The recommended way to invoke subprocesses is the run() function for all use cases it can handle; for more advanced cases the lower-level Popen interface can be used directly.
def is_windows():
import sys
return sys.platform.startswith('win')
# decrypt an encrypted password (relies on an external 'password' CLI)
def decrypt_password(encrypted_password: str):
password_command = 'password'
cmd = subprocess.Popen(f'{password_command} decrypt {encrypted_password}', stdin=subprocess.PIPE,
stderr=subprocess.PIPE,
stdout=subprocess.PIPE, universal_newlines=True, shell=True, bufsize=1)
return '\n'.join(cmd.stdout.readlines())
# encrypt a plaintext password (same external 'password' CLI)
def encrypt_password(decrypted_password: str):
password_command = 'password'
cmd = subprocess.Popen(f'{password_command} encrypt {decrypted_password}', stdin=subprocess.PIPE,
stderr=subprocess.PIPE,
stdout=subprocess.PIPE, universal_newlines=True, shell=True, bufsize=1)
stderr_message = '\n'.join(cmd.stderr.readlines())
if stderr_message.strip() != '':
raise Exception(stderr_message)
return '\n'.join(cmd.stdout.readlines())
def get_config_file_path():
path = os.path.abspath(os.path.join(os.getcwd(), ".."))
return f'{path}/algo_config/config.ini'
def get_relative_config_file_path():
return "../algo_config/config.ini"
def get_config():
    # resolve the absolute path of the config file
    path = get_config_file_path()
    # parse the ini file
    cf = ConfigParser()
    cf.read(path, encoding='utf-8')
return cf
def get_env(env: str = None):
    # load the parsed config
    cf = get_config()
    # initialize the environment name, defaulting to 'default'
    env = 'default' if env is None else env
    # if an environment marker file exists, it overrides the initial value
    if os.path.exists('/data/env_conf/env_name.conf'):
        with open('/data/env_conf/env_name.conf') as env_f:
            env = env_f.read().strip()
return cf, env
def get_user_pwd(env: str = None):
    # resolve the config object and the environment name
    cf, env = get_env(env)
    # section name for the current environment
    partition = f'postgresql_{env}'
    # read the user name and decrypt the stored password
    user = cf.get(partition, 'user')
    password = decrypt_password(cf.get(partition, 'password'))
# print(f'current env: {env}, read partition: {partition}, user: {user}, password: {password}')
return user, password
def get_conn_info(env=None):
    # resolve the config object and the environment name
    cf, env = get_env(env)
    # section name for the current environment
    partition = f'postgresql_{env}'
    # read host, port and the database to connect to
    host = cf.get(partition, 'host')
    port = cf.get(partition, 'port')
    db = cf.get(partition, 'db')
# print(f'current env: {env}, read partition: {partition}, host: {host}, port: {port}, db: {db}')
return host, port, db
def replace_value(x):
if x == "''::character varying":
return ' '
elif x == 'now()':
return datetime.now()
elif x is np.nan:
return x
elif x is None:
return x
elif x.startswith('('):
return re.findall(r'\((.*?)\)', x)[0]
else:
return x
def read_dataset(sql):
    # read connection info from the config file
    user, password = get_user_pwd()
    host, port, db = get_conn_info()
    # connect and load the query result into a pandas DataFrame
    conn = psycopg2.connect(database=db, user=user, password=password, host=host, port=port)
df = pd.read_sql(sql, con=conn)
return df
def write_dataset(df, schema, table_name, if_exists='fail'):
"""
if_exists mode:
fail: 数据库不存在目标表时, 根据待写入数据创建对应的数据表, 数据可正常写入, 表存在则写入失败
append: 数据库中存在目标表, 将数据追加到目标表中
replace: 数据库中存在目标表, 将目标表中的数据替换为当前数据
利用copy的方式将数据写入目标表
"""
# 配置文件信息读取
user, password = get_user_pwd()
host, port, db = get_conn_info()
# 连接数据库
db_engine = create_engine(f'postgresql://{user}:{password}@{host}:{port}/{db}')
string_data_io = io.StringIO()
df.to_csv(string_data_io, sep=',', index=False, header=True)
pd_sql_engine = pd.io.sql.pandasSQL_builder(db_engine)
table = pd.io.sql.SQLTable(table_name, pd_sql_engine, frame=df, index=False, if_exists=if_exists, schema=schema)
table.create()
string_data_io.seek(0)
with db_engine.connect() as conn:
with conn.connection.cursor() as cursor:
copy_cmd = f"COPY {schema}.{table_name} FROM STDIN WITH HEADER DELIMITER ',' CSV"
cursor.copy_expert(copy_cmd, string_data_io)
conn.connection.commit()
def insert_df_to_table(df, schema, table_name, write_flag="append"):
    # fetch all columns of the target table, together with their default values
    sql = f"""select column_name, column_default from information_schema.columns where table_schema = '{schema}' and
    table_name = '{table_name}' order by ordinal_position """
    field_df = read_dataset(sql)
    field_df['column_default'] = field_df['column_default'].apply(lambda x: replace_value(x))
    field_dict = dict(zip(field_df['column_name'], field_df['column_default']))
    field_list = field_df['column_name'].tolist()
    # columns present in the incoming DataFrame
    columns_list = list(df.columns)
    # columns the DataFrame is missing
    other_list = list(set(field_list) - set(columns_list))
    # fill the missing columns with their default values
    for x in other_list:
        value = field_dict[x]
        df[x] = value
    # reorder the columns to match the table
    df = df[field_list]
    # write the data
write_dataset(df, schema, table_name, write_flag)
if __name__ == '__main__':
user,passwd = get_user_pwd(env = 'dev')
print(passwd)
def init_table():
    # connection info is resolved via the helpers defined above
    user, password = get_user_pwd()
    host, port, db = get_conn_info()
    sql = """truncate arp_dw1.tb3_bom_safety_df;"""
    connect = psycopg2.connect(dbname=db, user=user, password=password, host=host, port=port)
    cur = connect.cursor()
    try:
        cur.execute(sql)
        connect.commit()
    except BaseException as e:
        print(repr(e))
        connect.rollback()
    connect.close()
def sku_meta_data_insert(result_list):
    # insert metadata rows; connection info is resolved via the helpers defined above
    user, password = get_user_pwd()
    host, port, db = get_conn_info()
    connect = psycopg2.connect(database=db, user=user, password=password, host=host, port=port)
    cur = connect.cursor()
    sql = """insert into arp_app.alg_sku_grouping(sku_grouping_name, sku_grouping_id, status, create_by) values(%s,
    %s, %s, %s) """
    try:
        cur.executemany(sql, result_list)
        connect.commit()
    except BaseException as e:
        print(repr(e))
        connect.rollback()
    connect.close()
def update_table(df, schema, table):
    # stage the increment into a temporary table, then insert from it into the target table
    write_dataset(df, 'arp_dw1', 'temp_sku_group_inc', 'fail')
    # connection info is resolved via the helpers defined above
    user, password = get_user_pwd()
    host, port, db = get_conn_info()
    connect = psycopg2.connect(database=db, user=user, password=password, host=host, port=port)
    cur = connect.cursor()
    sql = f"""insert into {schema}.{table}(unit, primary_category, secondary_category, material_code,
    sku_grouping_id, status, create_by) select * from arp_dw1.temp_sku_group_inc"""
    try:
        cur.execute(sql)
        connect.commit()
    except BaseException as e:
        print(repr(e))
        connect.rollback()
    drop_sql = """drop table arp_dw1.temp_sku_group_inc"""
    try:
        cur.execute(drop_sql)
        connect.commit()
    except BaseException as e:
        print(repr(e))
        connect.rollback()
    connect.close()
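A short end-to-end sketch of the PostgreSQL helpers above; the source/target tables and the derived column are hypothetical and only illustrate the read, fill-missing-columns, write flow.
# read a query result into a pandas DataFrame
df = read_dataset("select material_code, unit from arp_dw1.some_source_table limit 100")  # hypothetical source table

# derive whatever columns the job needs
df['status'] = 1

# missing target columns are filled from the table defaults, columns are re-ordered
# to match the target, then the rows are appended via COPY
insert_df_to_table(df, schema='arp_dw1', table_name='some_target_table', write_flag='append')  # hypothetical target table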
4 Test code predict_pred.py
from pyspark.sql import SparkSession
from pyspark.sql import functions as fn
from pyspark.ml import Pipeline, util
from pyspark.ml import feature as ft
import pandas as pd
import time
import sys
print(sys.path)
sys.path.append('/root/spark-rdbms')  # add the module directory to sys.path
print(sys.path)
from utils import get_config
from utils import get_user_pwd
from utils import get_conn_url
from utils import get_conn_properties
from utils import get_spark_conf
from utils import read_dataset
from utils import write_dataset
from utils import get_sql_conn
from utils import mysql_query_data
from utils import mysql_insert_or_update_data
if __name__ == '__main__':
    # (1) Spark reads from Hive
    # start the SparkSession
    spark = SparkSession.\
        builder.\
        appName("spark_io").\
        config(conf=get_spark_conf()).\
        enableHiveSupport().\
        getOrCreate()
    select_dt = "2020-09-09"
    strSql = f"""
    select * from yth_src.kclb where dt = '{select_dt}'
    """
    df = spark.sql(strSql)
    df.show()
    # (2) Spark writes to Hive (dynamic partition insert)
    # dt = time.strftime("%Y-%m-%d", time.localtime())
    # print(dt)
    # # enable dynamic partitioning
    # spark.sql("set hive.exec.dynamic.partition.mode=nonstrict")
    # spark.sql("set hive.exec.dynamic.partition=true")
    # spark.sql(f"""
    # insert overwrite table yth_dw.shop_inv_turnover_rate partition (dt)
    # select
    #     shop_code,
    #     pro_code,
    #     current_timestamp as created_time,
    #     '{dt}' as dt
    # from shop_inv_turnover_rate_table_db
    # """)
    # (3) Spark reads from MySQL via JDBC
    # fetch user and password from the config
    strUser, strPwd = get_user_pwd()
    jdbcDf = read_dataset(
        spark=spark,
        strURL=get_conn_url(),
        strDBTable='hive_kclb',
        strUser=strUser,
        strPassword=strPwd,
        isConcurrent=False
    )
    jdbcDf.show()
    # (4) Spark writes to MySQL via JDBC
    # write_dataset(
    #     srcDF=jdbcDf,
    #     strURL=get_conn_url(),
    #     desDBTable='hive_kclb2',
    #     dictProperties=get_conn_properties(),
    #     mode='append'
    # )
    # (5) pymysql reads from MySQL
    strSql2 = """
    select * from hive_kclb limit 5
    """
    df2 = mysql_query_data(strSql2)  # returns list[dict]
    df2 = pd.DataFrame(df2)
    print(df2)
    # (6) the same query read through pandas.read_sql on a pymysql connection
    conn = get_sql_conn()
    df3 = pd.read_sql(strSql2, con=conn)
    print(df3)
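The imported mysql_insert_or_update_data helper is not exercised above; a minimal sketch (the target table and values are hypothetical):
from utils import mysql_insert_or_update_data

# hypothetical DML against a MySQL table
strSql3 = """
update hive_kclb2 set status = 1 where shop_code = 'S001'
"""
mysql_insert_or_update_data(strSql3)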
Supplement: an extended Spark JDBC read (passing a subquery as dbtable)
from pyspark.sql import SparkSession
import sys
from utils import get_spark_conf
from utils import get_user_pwd
from utils import get_conn_url
from utils import read_dataset
if __name__ == "__main__":
select_dt = sys.argv[1]
print("分区:",select_dt)
# spark初始化
spark = SparkSession.\
builder.\
appName("spark_job").\
config(conf=get_spark_conf()).\
enableHiveSupport().\
getOrCreate()
strUser, strPwd = get_user_pwd()
mysql_query_1 = f"""
(
select
t1.purchase_date,
t1.shop_code,
t3.shop_name,
t2.create_datetime,
t2.status,
if(t2.status=1,'已发布','未发布') as issue_status
from (
select distinct purchase_date,shop_code
from prod_yth_hn.shop_replenishment
) t1
left join (
select *
from prod_yth_hn.shop_replenishment_info
) t2 on t2.shop_code=t1.shop_code and t2.purchase_date=t1.purchase_date
left join prod_yth_hn.hive_shop_attribute t3
on t3.shop_code=t1.shop_code
) as res
"""
mysql_df_1 = read_dataset(
spark = spark,
strURL = get_conn_url(),
strDBTable = mysql_query_1,
strUser = strUser,
strPassword = strPwd,
isConcurrent = False
)
print(type(mysql_df_1))
    mysql_df_1.show()
The job is submitted with spark-submit, specifying the Python interpreter and the MySQL JDBC driver JAR:
#!/usr/bin/bash
base_dir="/root/spark-jdbc-test"
# spark-submit binary
spark_interpreter="/opt/spark-2.4.4/bin/spark-submit"
# Python interpreter to use for PySpark
python_interpreter="/app/anaconda3/bin/python"
if [[ $# -eq 1 ]]; then
select_dt=$1
else
select_dt=$(date -d "1 day ago" +"%Y-%m-%d")
fi
echo "select dt is :${select_dt}"
# run the job
${spark_interpreter} --master yarn \
--num-executors 4 \
--conf spark.pyspark.python=${python_interpreter} \
--executor-memory 4g \
--driver-class-path mysql-connector-java-8.0.21.jar \
--jars mysql-connector-java-8.0.21.jar ${base_dir}/hn_arp_desc_for_store_wh.py "${select_dt}"
5 Launch script func_forecast.sh
#!/bin/bash
# spark-submit binary
spark_interpreter="/opt/spark-2.4.4/bin/spark-submit"
# Python interpreter to use for PySpark
python_interpreter="/app/anaconda3/bin/python"
function func_forecast()
{
cd /root/spark-rdbms
    # (1) --packages: Maven coordinates of the JARs the job depends on
    #     e.g. --packages mysql:mysql-connector-java:5.1.27 --repositories http://maven.aliyun.com/nexus/content/groups/public/
    #     --repositories is the Maven repository to download mysql-connector-java from; if omitted, the default
    #     repositories of the local Maven installation are used
    #     multiple dependencies are listed the same way, separated by commas
    #     downloaded packages end up in the current user's ~/.ivy2/jars directory
    #     use case: the JAR does not need to exist locally; cluster nodes download it from the given Maven repository
    #     (a --packages example is shown after this script)
    # (2) --jars JARS
    #     use case: the corresponding JAR files must exist locally
    #     --driver-class-path adds extra classpath entries for the driver; with --jars the paths are added automatically; multiple entries are separated by colons (:)
    # run the job
${spark_interpreter} --master yarn \
--conf spark.pyspark.python=${python_interpreter} \
--driver-class-path mysql-connector-java-8.0.11.jar \
--jars mysql-connector-java-8.0.11.jar predict_pred.py
if [[ $? -ne 0 ]]; then
echo "--> 执行失败"
exit 1
fi
}
func_forecast
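If the MySQL driver JAR is not present locally, the same job can be submitted with --packages so the nodes pull the driver from a Maven repository instead of shipping a local file. A minimal sketch reusing the variables defined above (the driver version shown is illustrative):
${spark_interpreter} --master yarn \
    --conf spark.pyspark.python=${python_interpreter} \
    --packages mysql:mysql-connector-java:8.0.11 \
    --repositories http://maven.aliyun.com/nexus/content/groups/public/ \
    predict_pred.py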
6 Reading the config file from shell
Write a shared reader script, ini_pro.sh:
#!/usr/bin/bash
###
# @Author: ydzhao
# @Description: read configuration file's parameters
# @Date: 2020-07-27 13:10:31
# @LastEditTime: 2020-09-13 07:33:14
# @FilePath: /git/code/ydzhao/spark-hive-rdbms/conf/ini.sh
###
INIFILE=$1
SECTION=$2
ITEM=$3
SPLIT=$4
NEWVAL=$5
function ReadINIfile(){
ReadINI=`awk -F $SPLIT '/\['$SECTION'\]/{a=1}a==1&&$1~/'$ITEM'/{print $2;exit}' $INIFILE`
echo $ReadINI
}
function WriteINIfile(){
    WriteINI=`sed -i "/^\[$SECTION\]/,/^\[/ {/^\[$SECTION\]/b;/^\[/b;s/^$ITEM[ ]*=.*/$ITEM=$NEWVAL/g;}" $INIFILE`
echo $WriteINI
}
if [ "$5" = "" ] ;then
ReadINIfile $1 $2 $3 $4
else
WriteINIfile $1 $2 $3 $4 $5
fi
# ./ini.sh $1 $2 $3 $4        read from the ini file
# ./ini.sh $1 $2 $3 $4 $5     write "newval" into the ini file
To read the config.ini shown above, you can do:
[root@ythbdnode01 ~]# mysql_conn_info_database=`bash /root/ini_pro.sh /root/config.ini dbms_parameters mysql_conn_info_database =`
[root@ythbdnode01 ~]# echo $mysql_conn_info_database
test_db
Four arguments are passed in:
- INIFILE the configuration file
- SECTION the section to read from
- ITEM the key within that section
- SPLIT the key/value separator
Because this password contains '=', its config line uses '@@@' as the separator, so we pass '@@@' as SPLIT:
[root@ythbdnode01 ~]# mysql_conn_info_passwd=`bash /root/ini_pro.sh /root/config.ini dbms_parameters mysql_conn_info_passwd @@@`
[root@ythbdnode01 ~]# echo $mysql_conn_info_passwd
=4bI=prodha
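When a fifth argument is supplied, the script switches to write mode and rewrites the value in place. An illustrative example against the config.ini above (the new value new_test_db is hypothetical); reading the key back afterwards should print new_test_db:
[root@ythbdnode01 ~]# bash /root/ini_pro.sh /root/config.ini dbms_parameters mysql_conn_info_database = new_test_db
[root@ythbdnode01 ~]# bash /root/ini_pro.sh /root/config.ini dbms_parameters mysql_conn_info_database =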