使用LSTM神经网络预测股价涨跌
Contents
将股价未来的表现分为涨、平、跌三类,可利用过去一段时间内的数据来预测未来一段时间(如未来5天)的股价涨跌。利用的数据包含技术面K线数据和基本面财务数据。此处选用LSTM网络结构。
数据获取
使用baostock可以获取个股的历史K线数据、季频财务数据、季频公司报告数据和宏观经济数据。由于季频公司报告数据非强制披露,宏观经济数据也缺失较多,因此仅使用前两者进行预测。
采用baostock来获取股票的历史K线数据和财务数据,使用方法可参考官网:
http://baostock.com/baostock/index.php/%E9%A6%96%E9%A1%B5
获取K线数据
使用query_history_k_data_plus()方法获取历史K线数据,可以通过参数设置获取日k线、周k线、月k线,以及5分钟、15分钟、30分钟和60分钟k线数据。能获取1990-12-19至当前时间的数据。
import baostock as bs
bs.login() #登录
k_df=bs.query_history_k_data_plus('sz.000001',adjustflag="2",fields='date,code,open,high,low,close,preclose,volume,amount,adjustflag,turn,tradestatus,pctChg,peTTM,pbMRQ,psTTM,pcfNcfTTM,isST').get_data()#.tail(50)
k_df

获取财务数据
baostock为查询财务数据提供了6个方法,分别从盈利、营运、成长、偿债、现金流量和杜邦指数六个角度反映公司的财务状况,查询时以季度为单位查询。可以通过参数设置获取对应年份、季度数据。baostock提供2007年至今数据。方法如下:
季频盈利能力:query_profit_data()
季频营运能力:query_operation_data()
季频成长能力:query_growth_data()
季频偿债能力:query_balance_data()
季频现金流量:query_cash_flow_data()
季频杜邦指数:query_dupont_data()
几个方法的参数都相同,均为代码、年份、季度:
code:股票代码,sh或sz.+6位数字代码,或者指数代码,如:sh.601398。sh:上海;sz:深圳。此参数不可为空;
year:统计年份,为空时默认当前年;
quarter:统计季度,可为空,默认当前季度。不为空时只有4个取值:1,2,3,4。
返回值有所不同。
将上述各项财务数据合并起来,得到较大的数据框:
bs.login()
def get_FS_df(code,year,quarter):
profit_df=bs.query_profit_data(code,year,quarter).get_data()
operation_df=bs.query_operation_data(code,year,quarter).get_data()
growth_df=bs.query_growth_data(code,year,quarter).get_data()
balance_df=bs.query_balance_data(code,year,quarter).get_data()
cash_flow_df=bs.query_cash_flow_data(code,year,quarter).get_data()
dupont_df=bs.query_dupont_data(code,year,quarter).get_data()
FS_df=pd.concat([profit_df,operation_df,growth_df,balance_df,cash_flow_df,dupont_df],axis=1)
FS_df=FS_df.loc[:,~FS_df.columns.duplicated()]#列去重
return FS_df
FS_df=get_FS_df(code="sz.000001", year=2007, quarter=1)
FS_df

数据保存
为了后面更自由地使用数据,我们将需要长时间获取的数据尽可能多地保存在本地,比如K线数据和季频财务数据。
K线数据:将所有能查询到的公司的K线数据存放在一张表中,表头包括股票代码、日期及其他各种指标,以日为单位;
季频财务数据:将所有能查询到的公司的季频财务数据存放在一张表中,表头包括股票代码、统计日期(季度)及其他各种指标,时间可以从2007年第1季度起,以季度为单位;
保存K线数据
首先使用query_stock_industry()方法获取所有股票代码:
from tqdm import tqdm import pandas as pd import baostock as bs bs.login() stock_industry_df=bs.query_stock_industry().get_data() code_list=list(stock_industry_df.code) #code_list stock_industry_df
接下来获取每只股票的K线数据。季频财务数据只能获取2007年以后的,K线数据这里从2010年开始。试了下前10个:
bs.login()
all_k_df_list=[]
for code in tqdm(code_list[:10]):
k_df=bs.query_history_k_data_plus(code,adjustflag="2",start_date='2010-01-01',fields='date,code,open,high,low,close,preclose,volume,amount,adjustflag,turn,tradestatus,pctChg,peTTM,pbMRQ,psTTM,pcfNcfTTM,isST').get_data()
all_k_df_list.append(k_df)
all_k_df=pd.concat(all_k_df_list)
all_k_df.to_feather('data/test.feather')
一只股票要花超过1s钟的时间,5000多只要花两个小时,无法忍受。baostock不支持多线程,因此这里采用多进程的方法。jupyter中无法直接运行多进程程序(会卡住),需要将程序写成py文件,再用run方法运行。且用run方法运行多进程py文件时时,子进程函数输出语句失效,因此这里将程序写入py文件,使用命令行执行py文件(中间有报错,需要将numpy降级到1.X,如1.26.4),并将数据框拼接后保存成体积较少且易于读取的feather格式:
%%writefile multiprocessing_get_k_data.py
#在jupyter notebook中运行时,需将代码写入py文件,再在notebook中运行py文件
from multiprocessing import Pool
import time
import baostock as bs
import pandas as pd
def print_error(value): #当进程函数报错时,该函数能输出错误,但不能指示出错误位置
print("error: ", value)
def mycallback(x): #该函数将子进程的处理结果添加到总的结果列表中
total_result_list.append(x)
def operation_fun(num,num_list): #子进程操作函数
print('\r%d/%d:%s'%(num_list.index(num)+1,len(num_list),num)) #进度
#time.sleep(1)
bs.login()#这里也需要登录
k_df=bs.query_history_k_data_plus(num,adjustflag="2",start_date='2010-01-01',fields='date,code,open,high,low,close,preclose,volume,amount,adjustflag,turn,tradestatus,pctChg,peTTM,pbMRQ,psTTM,pcfNcfTTM,isST').get_data()
return k_df
if __name__ == '__main__':
__spec__ = "ModuleSpec(name='builtins', loader=<class '_frozen_importlib.BuiltinImporter'>)"#添加此行可在jupyter notebook中重复运行
#num_list=list(range(100)) #待处理的数据列表
bs.login()
stock_industry_df=bs.query_stock_industry().get_data()
code_list=list(stock_industry_df.code)
total_result_list=[]
process_num=20 #设置进程数
pool = Pool(process_num)
start_time=time.time()
for code in code_list[:]:
pool.apply_async(operation_fun, (code,code_list), callback=mycallback,error_callback=print_error)
pool.close()
pool.join()
end_time=time.time()
print('%d进程处理%d个数,耗时%.2fs'%(process_num,len(code_list),end_time-start_time))
print(len(total_result_list))
all_k_df=pd.concat(total_result_list)
all_k_df.reset_index(drop=True,inplace=True)
all_k_df.to_feather('data/k_2010-2025.feather')
使用命令行执行上述py文件即可。
保存季频财务数据
运营同样的方法获取季频财务数据:
%%writefile multiprocessing_get_FS_data.py
#在jupyter notebook中运行时,需将代码写入py文件,再在notebook中运行py文件
from multiprocessing import Pool
import time
import baostock as bs
import pandas as pd
def print_error(value): #当进程函数报错时,该函数能输出错误,但不能指示出错误位置
print("error: ", value)
def mycallback(x): #该函数将子进程的处理结果添加到总的结果列表中
total_result_list.append(x)
def operation_fun(num,num_list): #子进程操作函数
print('\r%d/%d:%s'%(num_list.index(num)+1,len(num_list),num)) #进度
#time.sleep(1)
bs.login()#这里也需要登录
FS_df_list=[]
for year in range(2007,2026):
for quarter in range(1,5):
FS_df=get_FS_df(num,year,quarter)
FS_df_list.append(FS_df)
merged_FS_df=pd.concat(FS_df_list)
return merged_FS_df
def get_FS_df(code,year,quarter):
profit_df=bs.query_profit_data(code,year,quarter).get_data()
operation_df=bs.query_operation_data(code,year,quarter).get_data()
growth_df=bs.query_growth_data(code,year,quarter).get_data()
balance_df=bs.query_balance_data(code,year,quarter).get_data()
cash_flow_df=bs.query_cash_flow_data(code,year,quarter).get_data()
dupont_df=bs.query_dupont_data(code,year,quarter).get_data()
FS_df=pd.concat([profit_df,operation_df,growth_df,balance_df,cash_flow_df,dupont_df],axis=1)
FS_df=FS_df.loc[:,~FS_df.columns.duplicated()]#列去重
return FS_df
if __name__ == '__main__':
__spec__ = "ModuleSpec(name='builtins', loader=<class '_frozen_importlib.BuiltinImporter'>)"#添加此行可在jupyter notebook中重复运行
#num_list=list(range(100)) #待处理的数据列表
bs.login()
stock_industry_df=bs.query_stock_industry().get_data()
code_list=list(stock_industry_df.code)
total_result_list=[]
process_num=20 #设置进程数
pool = Pool(process_num)
start_time=time.time()
for code in code_list[:]:
pool.apply_async(operation_fun, (code,code_list), callback=mycallback,error_callback=print_error)
pool.close()
pool.join()
end_time=time.time()
print('%d进程处理%d个数,耗时%.2fs'%(process_num,len(code_list),end_time-start_time))
print(len(total_result_list))
all_FS_df=pd.concat(total_result_list)
all_FS_df.reset_index(drop=True,inplace=True)
all_FS_df.to_feather('data/FS_2007-2025.feather')
数据清洗及样本构建
模型的输入应包含K线数据、季频财务数据,这里使用最近D天的K线数据、最近Q个季度的财务数据来预测股价在未来F天的涨跌情况。
此处使用K线数据构建样本标签。取连续的(D+F)天的K线数据,若最后F天内的平均收盘价格(此处也可使用其他价格,或其平均价格)高于当前收盘价格(或者自定义其他价格)P%,则标记为涨;若低于当前价格P%,则标记为跌;否则标记为平;
目标是构建N个样本,可在1000多万行的K线数据框中随机选取N行,取这些行股票之前的D天K线数据和之后的F天K线数据计算样本标签,若数据长度不满足则舍弃。
取样会是一个漫长的过程,因此需要将生成的样本固定下来,以便后面直接读取。由于K线数据和财务数据具有不同的维度,因此分别保存为一个文件。
import pandas as pd import baostock as bs import numpy as np import matplotlib.pyplot as plt from tqdm import tqdm import datetime
K线数据清洗
读取k线数据:
all_k_df=pd.read_feather('data/k_2010-2025.feather')
all_k_df.reset_index(drop=True,inplace=True) #重置行索引
all_k_df
查看缺失值个数:
all_k_df.isna().sum()

数据无缺失值,但有空值,将空值替换为缺失值再统计:
all_k_df.replace('',np.nan,inplace=True) #空值替换为缺失值以便进行筛选
all_k_df.isna().sum()

不同字段缺失值的数量不同。应根据各字段的含义采用不同的策略进行填充。这里查看各字段的含义:

查看各字段数据类型:
all_k_df.dtypes
各字段均为object类型,为便于处理,这里将数字列转换为float类型:
for c in tqdm(all_k_df.columns[2:]):
all_k_df[c]=all_k_df[c].astype('float64')
all_k_df.dtypes
对于停牌日(tradestatus=0)的股票,这里将volume、amount、turn、pctChg填充为0,而peTTM、pbMRQ、psTTM、pcfNcfTTM 则可以使用该股票最近一次的数据填充。
all_k_df.loc[all_k_df['tradestatus']==0,['volume','amount','turn','pctChg']]=[0,0,0,0] all_k_df.isna().sum()

对于余下未填充的缺失值,使用前向填充和后向填充的方法。
如果将含有缺失值的股票的k线数据单独分离出来再进行填充,填充后再拼接在一起,像下面这样,那么速度会巨慢:
# 获取含有缺失值的股票列表
lack_k_code_list=all_k_df[all_k_df.isna().sum(axis=1)>=1]['code'].unique()
print(len(lack_k_code_list))#这里5000多只股票均有缺失值
"""
filled_k_df_list=[]
for code in tqdm(lack_k_code_list):
k_df=all_k_df[all_k_df['code']==code]
k_df.ffill(inplace=True)
k_df.ffill(inplace=False)
filled_k_df_list.append(k_df)
filled_all_k_df=pd.concat([all_k_df[~all_k_df.code.isin(lack_k_code_list)],pd.concat(filled_k_df_list)])
filled_all_k_df.isna().sum()
"""
这里改变一下策略,缺失值逐个填充:
for c in tqdm(['volume','amount','turn','pctChg','peTTM','pbMRQ','psTTM','pcfNcfTTM']):
na_index_list=all_k_df[all_k_df[c].isna()].index
for i in na_index_list:#向前填充,若首行为空则会遗漏
if i and all_k_df.loc[i,'code']==all_k_df.loc[i-1,'code']:
all_k_df.loc[i,c]=all_k_df.loc[i-1,c]
na_index_list=all_k_df[all_k_df[c].isna()].index
for i in na_index_list[::-1]:#向后填充
if i!=len(all_k_df) and all_k_df.loc[i,'code']==all_k_df.loc[i+1,'code']:
all_k_df.loc[i,c]=all_k_df.loc[i+1,c]
all_k_df.isna().sum()

当前已填充完成。所有看adjustflag均相同,该特征对于分类无意义,可以去掉。
filled_k_df=all_k_df.drop('adjustflag',axis=1)
季频财务数据清洗
all_FS_df=pd.read_feather('data/FS_2007-2025.feather')
同样,将空值替换为缺失值再统计并排序:
all_FS_df.replace('',np.nan,inplace=True) #空值替换为缺失值以便进行筛选
all_FS_df.isna().sum().sort_values()
查看缺失值所占百分比:
na_percent_series=all_FS_df.isna().sum()/len(all_FS_df) na_percent_series.sort_values()

可以看到ebitToInterest(已或利息倍数)、MBRevenue(主营营业收入)缺失较多,填充意义不大,后面不考虑此两项。
接下来对all_FS_df进行处理。
对于每只股票,都需要将其季频财务数据整理为连续的。先根据statDate列提取出年份和月份。
先看下统计日期,仅有4种,都在每季度的最后一天:
set(map(lambda s:s[5:],all_FS_df['statDate']))
{'03-31', '06-30', '09-30', '12-31'}
接下来提取出年份并根据统计日期确定季度:
if 'year' not in all_FS_df.columns:
all_FS_df.insert(3,'year',list(map(lambda s:int(s[:4]),all_FS_df['statDate'])))
if 'quarter' not in all_FS_df.columns:
all_FS_df.insert(4,'quarter',list(map(lambda s:1 if s[5:7]=='03' else 2 if s[5:7]=='06' else 3 if s[5:7]=='09' else 4,all_FS_df['statDate'])))
查看每只股票的季频财务数据是否连续,打印出季频财务数据有完全缺失的股票,根据前后的时间差来计算有无缺失:
lack_FS_code_list=[]
for code in tqdm(all_FS_df.code.unique()):
code_df=all_FS_df[all_FS_df.code==code]
quarters=(code_df.year.iloc[-1]-code_df.year.iloc[0])*4+(code_df.quarter.iloc[-1]-code_df.quarter.iloc[0])+1
if len(code_df)<quarters:
lack_FS_code_list.append(code)
print(len(lack_FS_code_list))
print(lack_FS_code_list[:5])
接下来对缺失的季度数据进行填充,这里直接采用上一季度的数据:
filled_FS_df_list=[]
for code in tqdm(lack_FS_code_list):
code_df=all_FS_df[all_FS_df.code==code]
i=1
while i<len(code_df):
last_date=datetime.datetime.strptime(code_df.iloc[i-1]['statDate'],'%Y-%m-%d')
this_date=datetime.datetime.strptime(code_df.iloc[i]['statDate'],'%Y-%m-%d')
if (this_date-last_date).days>92:
last_quarter=code_df.iloc[i-1]['quarter']
last_year=code_df.iloc[i-1]['year']
added_row_df=code_df.iloc[[i-1]].copy()
added_row_df['year']=last_year if last_quarter<4 else last_year+1
added_row_df['quarter']=last_quarter+1 if last_quarter<4 else 1
date_str=str(last_year)+('-03-31' if last_quarter==4 else '-06-30' if last_quarter==1 else '-09-30' if last_quarter==2 else '-12-31')
added_row_df['pubDate']=date_str
added_row_df['statDate']=date_str
code_df=pd.concat([code_df.iloc[:i],added_row_df,code_df.iloc[i:]])
break
i+=1
filled_FS_df_list.append(code_df)
all_FS_df=pd.concat([all_FS_df[~all_FS_df.code.isin(lack_FS_code_list)],pd.concat(filled_FS_df_list)])
all_FS_df
移除缺失值较多的列:
all_FS_df.drop(columns=['ebitToInterest','MBRevenue'],inplace=True) all_FS_df
接下来对缺失值进行填充。这里需要取出每只股票,单独填充,填充后再拼接起来。对于每只股票,数据按照时间顺序排列,因此对于缺失的数据,采用之前的数据向后填充;如果之前没有数据,则用后面的数据向前填充:
filled_FS_df_list=[]
for code in tqdm(all_FS_df.code.unique()):
code_df=all_FS_df[all_FS_df.code==code]
if code_df.isna().sum().sum():
code_df.ffill(inplace=True)
code_df.bfill(inplace=True)
filled_FS_df_list.append(code_df)
filled_FS_df=pd.concat(filled_FS_df_list)
filled_FS_df
还有一部分股票所有季度的某个数据完全缺失,这里使用平均值进行填充。先查看下数据类型:
filled_FS_df.dtypes
涉及财务数据的特征数据类型为object,直接获取其平均值会失败,这里先转换下数据类型,填充后再替换原来位置的数据:
feature_df=filled_FS_df.loc[:,'roeAvg':].astype('float64')
feature_df.fillna(feature_df.mean(),inplace=True)
filled_FS_df.loc[:,'roeAvg':]=feature_df
filled_FS_df
样本构建
现在通过随机采样的方式从清洗后的数据框中选取时间序列片段构建样本:
K线数据的维度:(样本数N,时间序列长度(D+F),K线数据特征数)
季频财务数据维度:(样本数N,历史季频财务数据长度Q,季频财务数据特征数)
重新清洗季频财务数据:
filled_FS_df.reset_index(inplace=True,drop=True) filled_FS_df
下面构造训练数据。根据将股票在未来F天内的平均价格与当日价格相比,涨幅高于C的标记为3,在-C到C之间的标记为2,低于-C的标记为1。这里open、high、low和close都计算一下,后面方便灵活改变。先尝试200个采样,观察结果:
N=20000
D=100
F=5
Q=12
C=0.1#变化0.1,即10%
sample_idx_list=random.sample(range(D+F,len(filled_k_df)),N)
sample_k_df_list=[]
sample_FS_df_list=[]
label_2d_list=[]#分别以open,high,low,close计算有四个标签
for i in tqdm(sample_idx_list[:200]):
sample_k_df=filled_k_df[i-(D+F)+1:i+1]
last_date=datetime.datetime.strptime(sample_k_df['date'].iloc[-1],'%Y-%m-%d')
first_date=datetime.datetime.strptime(sample_k_df['date'].iloc[0],'%Y-%m-%d')
if len(sample_k_df['code'].unique())==1 and (last_date-first_date).days<=1.6*D:#100个交易日,前后的时间差在150天左右,超过160天的可能会有较多停牌日,此类过滤掉
# 查找预测当天最近的FS_past个季频财务报告,获取不全的也舍弃掉
predict_date=sample_k_df['date'].iloc[-(F+1)]
code=sample_k_df['code'].iloc[-1]
past_FS_df=filled_FS_df[(filled_FS_df.code==code) & (filled_FS_df.pubDate<=predict_date)]
if len(past_FS_df)>=Q:
sample_k_df_list.append(sample_k_df.loc[:,'open':])#仅保存数字特征值
sample_FS_df_list.append(past_FS_df[-Q:].loc[:,'roeAvg':])
future_average_price_array=sample_k_df.iloc[-5:][['open','high','low','close']].mean().to_numpy()
now_price_array=sample_k_df.iloc[[-6]][['open','high','low','close']].mean().to_numpy()#此处可以选当前价格或过去几天的平均价格作为计算基准
change_percent_array=future_average_price_array/now_price_array-1
label_list=list(map(lambda x:3 if x>=C else 2 if x>=-C else 1,change_percent_array))
label_2d_list.append(label_list)
观察收盘价每类标签的数量:
np.unique(np.array(label_2d_list)[:,3],return_counts=True) (array([1, 2]), array([ 3, 134]))
可以看到,准备200个采样,保留的仅有一半,不同标签数量差异大,样本极不均衡,需要更改策略。
可在整个K线数据框中,对几乎所有行先打上标签,然后再从每类标签中取样。首先插入4列标签列:
if 'open_label' not in filled_k_df.columns:
filled_k_df.insert(2,'open_label',np.nan)
if 'high_label' not in filled_k_df.columns:
filled_k_df.insert(3,'high_label',np.nan)
if 'low_label' not in filled_k_df.columns:
filled_k_df.insert(4,'low_label',np.nan)
if 'close_label' not in filled_k_df.columns:
filled_k_df.insert(5,'close_label',np.nan)
接下来要填充这四列。1000多万行,如果使用如下逐行计算的方法,速度慢到难以想象:
"""
for i in tqdm(range(len(filled_k_df)-F)):
if filled_k_df.iloc[i]['code']==filled_k_df.iloc[i+F]['code']:
now_price_array=np.array(filled_k_df[i:i+1][['open','high','low','close']].mean())
future_mean_price_array=np.array(filled_k_df[i+1:i+6][['open','high','low','close']].mean())
label_list=list(map(lambda x:3 if x>=C else 2 if x>=-C else 1,future_mean_price_array/now_price_array-1))
filled_k_df.loc[i,['open_label','high_label','low_label','close_label']]=label_list
filled_k_df
"""
因此这里使用rolling方法,计算出四个价格列的移动平均值,再将未来几天的移动平均价格与当前价格比较,计算出标签。这里不用考虑不同股票的价格参与了同一平均值的计算,边缘处的计算结果后面取样时会被过滤掉:
for c in tqdm(['open_label','high_label','low_label','close_label'][:]):
s=pd.concat([filled_k_df[c[:-6]].rolling(F).mean()[F:],pd.Series([np.nan]*F)])#补全长度
change_array=np.array(s)/filled_k_df[c[:-6]]
filled_k_df[c]=(change_array>=(1-C)).astype('int')-(change_array<=(1+C)).astype('int')+2
filled_k_df

好了,查看下四列的各标签数量:
print(filled_k_df['open_label'].value_counts()) print(filled_k_df['high_label'].value_counts()) print(filled_k_df['low_label'].value_counts()) print(filled_k_df['close_label'].value_counts()) open_label 1 13087572 Name: count, dtype: int64 high_label 1 13087572 Name: count, dtype: int64 low_label 2 12514806 3 325819 1 246947 Name: count, dtype: int64 close_label 2 12469206 3 375248 1 243118 Name: count, dtype: int64
从close_label列中三类标签分别选取10000个:
sample_k_df_list=[]
sample_FS_df_list=[]
label_2d_list=[]
for i in [1,2,3]:
index_list=list(filled_k_df[filled_k_df['close_label']==i].index)
count=0
while count<10000 and len(index_list)>0:
print('\r',count,end='')
sample_index=random.sample(index_list,1)[0]
index_list.remove(sample_index)
if sample_index<D-1:
continue
sample_k_df=filled_k_df[sample_index-D+1:sample_index+1]
last_date=datetime.datetime.strptime(sample_k_df['date'].iloc[-1],'%Y-%m-%d')
first_date=datetime.datetime.strptime(sample_k_df['date'].iloc[0],'%Y-%m-%d')
if len(sample_k_df['code'].unique())==1 and (last_date-first_date).days<=1.6*D:#100个交易日,前后的时间差在150天左右,超过160天的可能会有较多停牌日,此类过滤掉
predict_date=sample_k_df['date'].iloc[-1]
code=sample_k_df['code'].iloc[-1]
past_FS_df=filled_FS_df[(filled_FS_df.code==code) & (filled_FS_df.pubDate<=predict_date)]
if len(past_FS_df)>=Q:
sample_k_df_list.append(sample_k_df.loc[:,'open':])#仅保存数字特征值
sample_FS_df_list.append(past_FS_df[-Q:].loc[:,'roeAvg':])
label_2d_list.append(np.array(sample_k_df.loc[sample_index,['open_label','high_label','low_label','close_label']]))
count+=1
#label_2d_list
print('Finished')
以上采样的数据转换为numpy数组,并查看其形状:
k_array=np.array(sample_k_df_list)
FS_array=np.array(sample_FS_df_list)
label_array=np.array(label_2d_list).astype('int')
k_array.shape,FS_array.shape,label_array.shape
((30000, 100, 15), (30000, 12, 38), (30000, 4))
保存数据,这里k线数据仅保存前D行:
k_array[:,:D,:].tofile(f'data/k_{k_array.shape[0]}x{k_array.shape[1]}x{k_array.shape[2]}.dat')
FS_array=FS_array.astype(float)
FS_array.tofile(f'data/FS_{FS_array.shape[0]}x{FS_array.shape[1]}x{FS_array.shape[2]}.dat')
label_array.tofile(f'data/label_{label_array.shape[0]}x{label_array.shape[1]}.dat')




