Shuffle和划分
下文以一个异常检测数据集（正负样本不平衡）为例。在生成第一批 TFRecord 时，我将正负样本分别写入单独的 TFRecord 文件，这样后续若要对正负样本采用不同的处理策略，就无需再解析 example_proto 来区分类别。例如在以下代码中，我对正负样本设置了不同的验证集比例（正样本约 1%，负样本约 30%），并将它们分别写入不同的验证集文件；其余样本则被随机分散写入 SUBFILE_NUM+1 个训练子文件。
from contextlib import ExitStack

import numpy as np
import tensorflow as tf
from tqdm.notebook import tqdm as tqdm
# TFRecord split: route every serialized example either to a per-class
# validation file or to one of SUBFILE_NUM + 1 randomly chosen training shards.
# Normal and anomaly samples come from separate source files, so each class
# gets its own validation fraction without parsing example_proto.
# SUBFILE_NUM, LEN_NORMAL_DATASET and LEN_ANOMALY_DATASET are not defined in
# this snippet; they are assumed to be set earlier in the notebook.
# Randomness uses NumPy's global RNG, so call np.random.seed(...) beforehand
# if the split must be reproducible.
NORMAL_VAL_FRACTION = 0.01  # fraction of normal samples sent to validation
ANOMALY_VAL_FRACTION = 0.3  # fraction of anomaly samples sent to validation


def _split_records(dataset, val_writer, train_writers, val_fraction, pbar):
    """Write each record of *dataset* to *val_writer* or a random train shard.

    A record goes to *val_writer* with probability *val_fraction*; otherwise
    it is written to one of *train_writers*, picked uniformly at random.
    *pbar* is advanced by one per record.
    """
    for example_proto in dataset:
        # Raw serialized bytes are copied as-is; no parsing is required.
        record = example_proto.numpy()
        if np.random.random() < val_fraction:
            val_writer.write(record)
        else:
            train_writers[np.random.randint(len(train_writers))].write(record)
        pbar.update(1)


raw_normal_dataset = tf.data.TFRecordDataset("normal_16_256.tfrecords", "GZIP")
raw_anomaly_dataset = tf.data.TFRecordDataset("anomaly_16_256.tfrecords", "GZIP")

# The ExitStack closes every writer even if iteration raises part-way, so no
# output file is left open on failure.
with ExitStack() as writer_stack:
    normal_val_writer = writer_stack.enter_context(
        tf.io.TFRecordWriter(r'ex_1/' + 'normal_val_16_256.tfrecords', "GZIP")
    )
    anomaly_val_writer = writer_stack.enter_context(
        tf.io.TFRecordWriter(r'ex_1/' + 'anomaly_val_16_256.tfrecords', "GZIP")
    )
    train_writer_list = [
        writer_stack.enter_context(
            tf.io.TFRecordWriter(
                r'ex_1/' + 'train_16_256_{}.tfrecords'.format(i), "GZIP"
            )
        )
        for i in range(SUBFILE_NUM + 1)
    ]
    with tqdm(total=LEN_NORMAL_DATASET + LEN_ANOMALY_DATASET) as pbar:
        _split_records(raw_normal_dataset, normal_val_writer,
                       train_writer_list, NORMAL_VAL_FRACTION, pbar)
        _split_records(raw_anomaly_dataset, anomaly_val_writer,
                       train_writer_list, ANOMALY_VAL_FRACTION, pbar)
读取
# Read the split files back as tf.data pipelines.
# BATCH_SIZE and map_func are not defined in this snippet; they are assumed to
# be set earlier in the notebook.
# NOTE(review): map_func is applied after .batch(), so it receives a batch of
# serialized examples (e.g. for tf.io.parse_example); confirm it is written
# for batched input.
train_filenames = [
    r'ex_1/' + 'train_16_256_{}.tfrecords'.format(i)
    for i in range(SUBFILE_NUM + 1)
]
raw_train_dataset = tf.data.TFRecordDataset(train_filenames, "GZIP")
# NOTE: each shard stores its normal records before its anomaly records,
# because the split step writes the normal source first. Shards are read one
# after another, so this 100000-element buffer only shuffles locally. If a
# shard is much larger than the buffer, nearby batches may be class-skewed.
raw_train_dataset = raw_train_dataset.shuffle(buffer_size=100000).batch(BATCH_SIZE)
parsed_train_dataset = raw_train_dataset.map(map_func=map_func)

raw_normal_val_dataset = tf.data.TFRecordDataset(
    r'ex_1/' + 'normal_val_16_256.tfrecords', "GZIP"
)
raw_anomaly_val_dataset = tf.data.TFRecordDataset(
    r'ex_1/' + 'anomaly_val_16_256.tfrecords', "GZIP"
)
# Validation sets are not shuffled. The misspelled names ("nomarl",
# "dateset") are kept because later cells may reference them.
parsed_nomarl_val_dataset = raw_normal_val_dataset.batch(BATCH_SIZE).map(map_func=map_func)
parsed_anomaly_val_dateset = raw_anomaly_val_dataset.batch(BATCH_SIZE).map(map_func=map_func)