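2. MLlib Logistic Regression in Practice: Data Preprocessing

Goal: use the provided data to predict newborn survival.

Background

Newborn survival data published by the US CDC: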

https://www.cdc.gov/nchs/data_access/vitalstatsonline.html

Dataset overview:

A sample of the data:

INFANT_ALIVE_AT_REPORT BIRTH_YEAR BIRTH_MONTH BIRTH_PLACE MOTHER_AGE_YEARS MOTHER_RACE_6CODE MOTHER_EDUCATION FATHER_COMBINED_AGE FATHER_EDUCATION
N 2015 2 1 29 3 9 99 9
N 2015 2 1 22 1 3 29 4
N 2015 2 1 38 1 4 40 3
N 2015 4 1 39 2 7 42 6
N 2015 4 1 18 3 2 99 9
N 2015 4 1 32 1 4 37 4
N 2015 5 1 22 3 3 25 2
N 2015 6 1 25 1 5 26 6
N 2015 6 1 26 1 7 32 6
N 2015 6 1 39 4 4 66 4
N 2015 7 1 25 3 4 22 3

Processing workflow:

  • Data cleaning
  • Feature selection
  • Model building
  • Model evaluation
  • Model application

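Implementing the preprocessing pipeline

  • Load the data as a DataFrame
from pyspark.sql import SparkSession
import pyspark.sql.types as typ

# The pyspark shell already provides `spark`; getOrCreate() returns that session
spark = SparkSession.builder.appName('births').getOrCreate()

data_file = 'file:///root/bigdata/data/births_train.csv'
# (column name, Spark data type) for each CSV column, in file order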
labels = [
    ('INFANT_ALIVE_AT_REPORT', typ.StringType()),
    ('BIRTH_YEAR', typ.IntegerType()),
    ('BIRTH_MONTH', typ.IntegerType()),
    ('BIRTH_PLACE', typ.StringType()),
    ('MOTHER_AGE_YEARS', typ.IntegerType()),
    ('MOTHER_RACE_6CODE', typ.StringType()),
    ('MOTHER_EDUCATION', typ.StringType()),
    ('FATHER_COMBINED_AGE', typ.IntegerType()),
    ('FATHER_EDUCATION', typ.StringType()),
    ('MONTH_PRECARE_RECODE', typ.StringType()),
    ('CIG_BEFORE', typ.IntegerType()),
    ('CIG_1_TRI', typ.IntegerType()),
    ('CIG_2_TRI', typ.IntegerType()),
    ('CIG_3_TRI', typ.IntegerType()),
    ('MOTHER_HEIGHT_IN', typ.IntegerType()),
    ('MOTHER_BMI_RECODE', typ.IntegerType()),
    ('MOTHER_PRE_WEIGHT', typ.IntegerType()),
    ('MOTHER_DELIVERY_WEIGHT', typ.IntegerType()),
    ('MOTHER_WEIGHT_GAIN', typ.IntegerType()),
    ('DIABETES_PRE', typ.StringType()),
    ('DIABETES_GEST', typ.StringType()),
    ('HYP_TENS_PRE', typ.StringType()),
    ('HYP_TENS_GEST', typ.StringType()),
    ('PREV_BIRTH_PRETERM', typ.StringType()),
    ('NO_RISK', typ.StringType()),
    ('NO_INFECTIONS_REPORTED', typ.StringType()),
    ('LABOR_IND', typ.StringType()),
    ('LABOR_AUGM', typ.StringType()),
    ('STEROIDS', typ.StringType()),
    ('ANTIBIOTICS', typ.StringType()),
    ('ANESTHESIA', typ.StringType()),
    ('DELIV_METHOD_RECODE_COMB', typ.StringType()),
    ('ATTENDANT_BIRTH', typ.StringType()),
    ('APGAR_5', typ.IntegerType()),
    ('APGAR_5_RECODE', typ.StringType()),
    ('APGAR_10', typ.IntegerType()),
    ('APGAR_10_RECODE', typ.StringType()),
    ('INFANT_SEX', typ.StringType()),
    ('OBSTETRIC_GESTATION_WEEKS', typ.IntegerType()),
    ('INFANT_WEIGHT_GRAMS', typ.IntegerType()),
    ('INFANT_ASSIST_VENTI', typ.StringType()),
    ('INFANT_ASSIST_VENTI_6HRS', typ.StringType()),
    ('INFANT_NICU_ADMISSION', typ.StringType()),
    ('INFANT_SURFACANT', typ.StringType()),
    ('INFANT_ANTIBIOTICS', typ.StringType()),
    ('INFANT_SEIZURES', typ.StringType()),
    ('INFANT_NO_ABNORMALITIES', typ.StringType()),
    ('INFANT_ANCEPHALY', typ.StringType()),
    ('INFANT_MENINGOMYELOCELE', typ.StringType()),
    ('INFANT_LIMB_REDUCTION', typ.StringType()),
    ('INFANT_DOWN_SYNDROME', typ.StringType()),
    ('INFANT_SUSPECTED_CHROMOSOMAL_DISORDER', typ.StringType()),
    ('INFANT_NO_CONGENITAL_ANOMALIES_CHECKED', typ.StringType()),
    ('INFANT_BREASTFED', typ.StringType())
]

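# Build the schema; the third StructField argument (False) declares each field non-nullable
schema = typ.StructType([
    typ.StructField(name, data_type, False) for name, data_type in labels
])

# header=True skips the first line; the schema is applied by column position
births = spark.read.csv(data_file, header=True, schema=schema)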
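This plain-Python sketch (no Spark needed) shows what the schema-on-read step does with each line. With `header=True` the first line is skipped, and each field is cast to its declared type by position. The order of `labels` must therefore match the column order in the file. The helper `read_with_schema`, the three-column schema and the tiny CSV string are hypothetical illustrations. The sample's header names deliberately differ from the schema names to show that only position matters.

```python
import csv
import io

# Hypothetical miniature schema: (column name, Python type used for casting)
labels = [
    ('INFANT_ALIVE_AT_REPORT', str),
    ('BIRTH_YEAR', int),
    ('MOTHER_AGE_YEARS', int),
]


def read_with_schema(text, labels, header=True):
    """Parse CSV text, skip the header line, and cast fields by position."""
    reader = csv.reader(io.StringIO(text))
    if header:
        next(reader)  # header names are skipped, not matched against labels
    rows = []
    for fields in reader:
        rows.append({name: cast(value)
                     for (name, cast), value in zip(labels, fields)})
    return rows


# Header names differ from the schema on purpose: only position matters
sample = "alive,year,age\nN,2015,29\nY,2015,22\n"
births = read_with_schema(sample, labels)
print(births[0])  # {'INFANT_ALIVE_AT_REPORT': 'N', 'BIRTH_YEAR': 2015, 'MOTHER_AGE_YEARS': 29}
```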
  • Select a subset of columns
# For infant survival, focus on the parents' information

selected_features = [
    'INFANT_ALIVE_AT_REPORT',
    'BIRTH_PLACE',
    'MOTHER_AGE_YEARS',
    'FATHER_COMBINED_AGE',
    'CIG_BEFORE',
    'CIG_1_TRI',
    'CIG_2_TRI',
    'CIG_3_TRI',
    'MOTHER_HEIGHT_IN',
    'MOTHER_PRE_WEIGHT',
    'MOTHER_DELIVERY_WEIGHT',
    'MOTHER_WEIGHT_GAIN',
    'DIABETES_PRE',
    'DIABETES_GEST',
    'HYP_TENS_PRE',
    'HYP_TENS_GEST',
    'PREV_BIRTH_PRETERM'
]

births_trimmed = births.select(selected_features)
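A misspelled name in `selected_features` makes `select` fail, typically with an `AnalysisException` saying the column cannot be resolved. A quick check on the driver can list every misspelled name at once. The function `missing_columns` and the abridged column lists below are illustrative assumptions. In the real pipeline the inputs would be `[name for name, _ in labels]` and `selected_features`.

```python
def missing_columns(selected, available):
    """Return the selected names that do not exist among the available columns."""
    available = set(available)
    return [name for name in selected if name not in available]


# Abridged, hypothetical column lists for illustration
available = ['INFANT_ALIVE_AT_REPORT', 'BIRTH_PLACE', 'MOTHER_AGE_YEARS', 'CIG_BEFORE']
selected = ['INFANT_ALIVE_AT_REPORT', 'MOTHER_AGE_YEARS', 'CIG_BEFOR']  # deliberate typo

print(missing_columns(selected, available))  # ['CIG_BEFOR']
```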
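  • Data cleaning and transformation
# Handling the number of cigarettes smoked during pregnancy (CIG_* columns):
#   0    : did not smoke
#   1-98 : actual number of cigarettes
#   99   : unknown
# Unknown values are treated as "did not smoke".
import pyspark.sql.functions as func


def correct_cig(feat):
    # Keep the value unless it is the "unknown" code 99, which becomes 0
    return func \
        .when(func.col(feat) != 99, func.col(feat)) \
        .otherwise(0)

# withColumn(name, col): the first argument is the column name, the second a
# Column expression; an existing column with the same name is replaced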

births_transformed = births_trimmed \
    .withColumn('CIG_BEFORE', correct_cig('CIG_BEFORE'))\
    .withColumn('CIG_1_TRI', correct_cig('CIG_1_TRI'))\
    .withColumn('CIG_2_TRI', correct_cig('CIG_2_TRI'))\
    .withColumn('CIG_3_TRI', correct_cig('CIG_3_TRI'))
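The `when(...).otherwise(0)` expression can be read as a per-value rule. The plain-Python sketch below mirrors it for one record; `correct_cig_value` and the sample record are hypothetical. In Spark SQL, `null != 99` evaluates to null, and `when` treats a null condition as false, so a missing value also ends up as 0. The sketch reproduces that. The dictionary comprehension over `cig_cols` plays the role of the four chained `withColumn` calls. In PySpark the same idea could be written as a `for` loop that reassigns `births_transformed` once per column.

```python
UNKNOWN = 99  # code for "unknown" in the CIG_* columns


def correct_cig_value(value):
    """Mirror when(col != 99, col).otherwise(0) for a single value.

    A None (SQL null) also falls through to 0, as in Spark SQL.
    """
    if value is not None and value != UNKNOWN:
        return value
    return 0


cig_cols = ['CIG_BEFORE', 'CIG_1_TRI', 'CIG_2_TRI', 'CIG_3_TRI']
row = {'CIG_BEFORE': 10, 'CIG_1_TRI': 99, 'CIG_2_TRI': 0, 'CIG_3_TRI': None}

# Apply the same correction to every cigarette column
cleaned = {k: (correct_cig_value(v) if k in cig_cols else v) for k, v in row.items()}
print(cleaned)  # {'CIG_BEFORE': 10, 'CIG_1_TRI': 0, 'CIG_2_TRI': 0, 'CIG_3_TRI': 0}
```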
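  • Recode the target variable and the other Y/N/U columns
# Find all string columns whose values are coded as Y/N/U

cols = [(col.name, col.dataType) for col in births_trimmed.schema]

YNU_cols = []

for name, data_type in cols:
    if data_type == typ.StringType():
        # Collect the distinct values of this column to the driver
        dis = births_trimmed.select(name) \
            .distinct() \
            .rdd \
            .map(lambda row: row[0]) \
            .collect()
        print(name, dis)
        # Treat the column as Y/N/U-coded if 'Y' occurs in it
        if 'Y' in dis:
            YNU_cols.append(name)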
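The loop keeps a string column only if the value 'Y' actually occurs in it. A Y/N/U column whose data happened to contain no 'Y' would therefore be skipped. Each `distinct().collect()` also runs a separate Spark job. The detection logic is easy to check in plain Python on a few records. `find_ynu_columns` and the toy rows below are hypothetical, chosen to show both a detected column and the missed case.

```python
def find_ynu_columns(rows, column_types):
    """Return string columns whose distinct values include 'Y'.

    rows: list of dicts, one per record; column_types: dict name -> Python type.
    """
    ynu_cols = []
    for name, col_type in column_types.items():
        if col_type is not str:
            continue
        distinct = {row[name] for row in rows}  # like select(name).distinct()
        if 'Y' in distinct:
            ynu_cols.append(name)
    return ynu_cols


column_types = {'INFANT_ALIVE_AT_REPORT': str, 'BIRTH_PLACE': str,
                'MOTHER_AGE_YEARS': int, 'DIABETES_PRE': str, 'HYP_TENS_PRE': str}
rows = [  # hypothetical toy records
    {'INFANT_ALIVE_AT_REPORT': 'N', 'BIRTH_PLACE': '1', 'MOTHER_AGE_YEARS': 29,
     'DIABETES_PRE': 'N', 'HYP_TENS_PRE': 'N'},
    {'INFANT_ALIVE_AT_REPORT': 'Y', 'BIRTH_PLACE': '1', 'MOTHER_AGE_YEARS': 22,
     'DIABETES_PRE': 'U', 'HYP_TENS_PRE': 'U'},
    {'INFANT_ALIVE_AT_REPORT': 'Y', 'BIRTH_PLACE': '2', 'MOTHER_AGE_YEARS': 38,
     'DIABETES_PRE': 'Y', 'HYP_TENS_PRE': 'N'},
]

# HYP_TENS_PRE is Y/N/U-coded but is missed because no 'Y' appears in it
print(find_ynu_columns(rows, column_types))  # ['INFANT_ALIVE_AT_REPORT', 'DIABETES_PRE']
```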

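# Recode the target variable (and the other Y/N/U columns) with a UDF
# that returns an integer.
#
# func.udf(f, returnType) wraps a Python function so it can be used in
# DataFrame expressions: argument 1 is the Python function, argument 2 its
# return type. The UDF is called with Columns, but Spark invokes the Python
# function once per row with the plain values of those columns, in the
# order they are passed.
recode_dictionary = {
    'YNU': {
        'Y': 1,
        'N': 0,
        'U': 0
    }
}

def recode(key, value):
    # key: the dictionary name ('YNU'); value: the cell value of the current row
    return recode_dictionary[key][value]

rec_integer = func.udf(recode, typ.IntegerType())

# Example: recode one column and compare it with the original values
# births.select([
#         'INFANT_NICU_ADMISSION',
#         rec_integer(
#             func.lit('YNU'), 'INFANT_NICU_ADMISSION'
#         ).alias('INFANT_NICU_ADMISSION_RECODE')]
#      ).take(5)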


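# Build the expressions for select(): a recode Column for each Y/N/U column,
# the plain column name for every other column.
# func.lit('YNU') is a literal Column that supplies 'YNU' for every row;
# rec_integer(func.lit('YNU'), x) maps each value of column x through
# recode_dictionary['YNU'], and .alias(x) keeps the original column name.
exprs_YNU = [
    rec_integer(func.lit('YNU'), x).alias(x)
    if x in YNU_cols
    else x
    for x in births_transformed.columns
]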

# >>> exprs_YNU
# [Column<recode(YNU,INFANT_ALIVE_AT_REPORT) AS `INFANT_ALIVE_AT_REPORT`>,
#  'BIRTH_PLACE', 'MOTHER_AGE_YEARS', 'FATHER_COMBINED_AGE', 'CIG_BEFORE',
#  'CIG_1_TRI', 'CIG_2_TRI', 'CIG_3_TRI', 'MOTHER_HEIGHT_IN',
#  'MOTHER_PRE_WEIGHT', 'MOTHER_DELIVERY_WEIGHT', 'MOTHER_WEIGHT_GAIN',
#  Column<recode(DIABETES_PRE, YNU) AS `DIABETES_PRE`>,
#  Column<recode(DIABETES_GEST, YNU) AS `DIABETES_GEST`>,
#  Column<recode(HYP_TENS_PRE, YNU) AS `HYP_TENS_PRE`>,
#  Column<recode(HYP_TENS_GEST, YNU) AS `HYP_TENS_GEST`>,
#  Column<recode(PREV_BIRTH_PRETERM, YNU) AS `PREV_BIRTH_PRETERM`>]


# The final, clean dataset
births_transformed = births_transformed.select(exprs_YNU)
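This plain-Python sketch shows what the final `select(exprs_YNU)` produces for one row. Each Y/N/U column passes through `recode('YNU', value)` (Y → 1, N and U → 0) and keeps its original name. Every other column is copied unchanged. `apply_exprs`, `recode_safe` and the sample record are hypothetical. `recode_dictionary[key][value]` raises `KeyError` for a null or for any value other than Y/N/U, which would likely fail the Spark job. The `recode_safe` variant returns `None` instead, and a UDF declared with `IntegerType()` turns that into a null.

```python
recode_dictionary = {'YNU': {'Y': 1, 'N': 0, 'U': 0}}


def recode(key, value):
    """Same lookup as the UDF body; raises KeyError for unexpected values."""
    return recode_dictionary[key][value]


def recode_safe(key, value):
    """Hypothetical variant: None (a null once returned from a UDF) for unexpected values."""
    return recode_dictionary[key].get(value)


def apply_exprs(row, columns, ynu_cols):
    """Mimic select(exprs_YNU): recode Y/N/U columns, copy the others unchanged.

    Keys keep the original column names and order, as `.alias(x)` does.
    """
    return {c: (recode('YNU', row[c]) if c in ynu_cols else row[c]) for c in columns}


columns = ['INFANT_ALIVE_AT_REPORT', 'MOTHER_AGE_YEARS', 'CIG_BEFORE', 'DIABETES_PRE']
ynu_cols = ['INFANT_ALIVE_AT_REPORT', 'DIABETES_PRE']
row = {'INFANT_ALIVE_AT_REPORT': 'Y', 'MOTHER_AGE_YEARS': 29,
       'CIG_BEFORE': 0, 'DIABETES_PRE': 'U'}

print(apply_exprs(row, columns, ynu_cols))
# {'INFANT_ALIVE_AT_REPORT': 1, 'MOTHER_AGE_YEARS': 29, 'CIG_BEFORE': 0, 'DIABETES_PRE': 0}
print([recode_safe('YNU', v) for v in ['N', None, 'X']])  # [0, None, None]
```

In PySpark, a native expression such as `func.when(func.col(x) == 'Y', 1).otherwise(0)` would give the same 1/0 coding without a Python UDF, though nulls would then map to 0. `births_transformed.groupBy('INFANT_ALIVE_AT_REPORT').count().show()` is a quick way to inspect the recoded label.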
