当LSTMPredictStock/config.json内的epochs大于1时出现的bug
我如何发现这个问题
在详细阅读完README后,我开始自己进行模型的训练。在注意到提供的数据对于每一支股票都只训练了1个epoch且loss大概为0.0020-0.1,自然而然我就想把epoch调高。通过在LSTMPredictStock/config.json内的training - epochs调成2,程序报错:
[Model] Training Started
[Model] 2 epochs, 8 batch size, 29 batches per epoch
2025-03-09 22:12:11.276718: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2025-03-09 22:12:11.276967: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 1996205000 Hz
Epoch 1/2
29/29 [==============================] - 2s 30ms/step - loss: 0.0024
Epoch 2/2
2025-03-09 22:12:13.764939: W tensorflow/core/framework/op_kernel.cc:1751] Invalid argument: TypeError: `generator` yielded an element that could not be converted to the expected type. The expected type was float32, but the yielded element was [array([[ 0. , 0. , 0. , 0. ],
[ 0.00276253, -0.00033106, 0.00230376, 0.00154365],
[ 0.00455639, 0.00548025, 0.00636496, 0.00553156],
[ 0.00534874, 0.00714018, 0.00698772, 0.00492981],
[ 0.00480491, 0.00315387, 0.00463258, 0.00367912],
[-0.00418113, -0.00538486, -0.00376593, -0.00524791],
[ 0.00094419, 0.00176553, 0.00281446, 0.00227668],
[ 0.00790458, 0.00602065, 0.00821311, 0.00788086],
[ 0.00727708, 0.00890354, 0.01185306, 0.00880289],
[ 0.00073471, -0.00148558, 0.00091246, -0.00060734],
[ 0.01349254, 0.01021229, 0.01257114, 0.0108043 ],
[ 0.02372086, 0.02297548, 0.02319355, 0.01668251],
[ 0.02945446, 0.02684999, 0.02926323, 0.02902177],
[ 0.02817995, 0.02719994, 0.02788556, 0.02784433],
[ 0.03695807, 0.02922592, 0.03577212, 0.03119914],
[ 0.0326502 , 0.03730632, 0.03973274, 0.03421444],
[ 0.03706311, 0.03287932, 0.03660752, 0.03450554],
[ 0.03988483, 0.03439958, 0.03887289, 0.03616712],
[ 0.03017368, 0.03424257, 0.03532363, 0.02857177],
[ 0.03855516, 0.03455752, 0.03746458, 0.0367456 ],
[ 0.03832089, 0.0379483 , 0.03854882, 0.03796184],
[ 0.04709962, 0.04396957, 0.04603146, 0.04606553],
[ 0.04521404, 0.04305011, 0.04424058, 0.04147213],
[ 0.03685116, 0.0401923 , 0.04021992, 0.0358785 ],
[ 0.0474408 , 0.04420525, 0.04624286, 0.04387761],
[ 0.04987269, 0.04401355, 0.04889452, 0.04502898],
[ 0.02905999, 0.03601027, 0.03544558, 0.02994194],
[ 0.02782917, 0.02909307, 0.0291289 , 0.02555896],
[ 0.03008661, 0.02874931, 0.02939694, 0.02995715]])
array([[ 0. , 0. , 0. , 0. ],
[ 0.00178892, 0.00581323, 0.00405187, 0.00398177],
[ 0.00257908, 0.00747371, 0.0046732 , 0.00338094],
[ 0.00203675, 0.00348608, 0.00232347, 0.00213218],
[-0.00692453, -0.00505548, -0.00605573, -0.00678109],
[-0.00181333, 0.00209729, 0.00050953, 0.0007319 ],
[ 0.00512789, 0.00635381, 0.00589577, 0.00632745],
[ 0.00450212, 0.00923766, 0.00952735, 0.00724806],
[-0.00202223, -0.0011549 , -0.00138809, -0.00214768],
[ 0.01070045, 0.01054684, 0.01024379, 0.00924638],
[ 0.0209006 , 0.02331426, 0.02084178, 0.01511553],
[ 0.0266184 , 0.02719005, 0.02689751, 0.02743577],
[ 0.0253474 , 0.02754012, 0.02552301, 0.02626014],
[ 0.03410133, 0.02956677, 0.03339144, 0.02960979],
[ 0.02980533, 0.03764984, 0.03734295, 0.03262044],
[ 0.03420609, 0.03322137, 0.03422491, 0.03291109],
[ 0.03702003, 0.03474214, 0.03648508, 0.03457011],
[ 0.02733564, 0.03458508, 0.03294398, 0.02698647],
[ 0.03569402, 0.03490013, 0.03508 , 0.0351477 ],
[ 0.0354604 , 0.03829204, 0.03616176, 0.03636206],
[ 0.04421495, 0.0443153 , 0.0436272 , 0.04445326],
[ 0.04233456, 0.04339553, 0.04184044, 0.03986694],
[ 0.03399472, 0.04053678, 0.03782902, 0.03428193],
[ 0.04455518, 0.04455105, 0.04383812, 0.04226871],
[ 0.04698038, 0.<br/>
**当LSTMPredictStock/config.json内的epochs大于1时出现的bug**
=====================================================
### 我如何发现这个问题
在详细阅读完README后,我开始自己进行模型的训练。在注意到提供的数据对于每一支股票都只训练了1个epoch且loss大概为0.0020-0.1,自然而然我就想把epoch调高。通过在LSTMPredictStock/config.json内的training - epochs调成2,程序报错:
[Model] Training Started
[Model] 2 epochs, 8 batch size, 29 batches per epoch
2025-03-09 22:12:11.276718: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2025-03-09 22:12:11.276967: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 1996205000 Hz
Epoch 1/2
29/29 [==============================] - 2s 30ms/step - loss: 0.0024
Epoch 2/2
2025-03-09 22:12:13.764939: W tensorflow/core/framework/op_kernel.cc:1751] Invalid argument: TypeError: generator
yielded an element that could not be converted to the expected type. The expected type was float32, but the yielded element was [array([[ 0. , 0. , 0. , 0. ],
[ 0.00276253, -0.00033106, 0.00230376, 0.00154365],
[ 0.00455639, 0.00548025, 0.00636496, 0.00553156],
[ 0.00534874, 0.00714018, 0.00698772, 0.00492981],
[ 0.00480491, 0.00315387, 0.00463258, 0.00367912],
[-0.00418113, -0.00538486, -0.00376593, -0.00524791],
[ 0.00094419, 0.00176553, 0.00281446, 0.00227668],
[ 0.00790458, 0.00602065, 0.00821311, 0.00788086],
[ 0.00727708, 0.00890354, 0.01185306, 0.00880289],
[ 0.00073471, -0.00148558, 0.00091246, -0.00060734],
[ 0.01349254, 0.01021229, 0.01257114, 0.0108043 ],
[ 0.02372086, 0.02297548, 0.02319355, 0.01668251],
[ 0.02945446, 0.02684999, 0.02926323, 0.02902177],
[ 0.02817995, 0.02719994, 0.02788556, 0.02784433],
[ 0.03695807, 0.02922592, 0.03577212, 0.03119914],
[ 0.0326502 , 0.03730632, 0.03973274, 0.03421444],
[ 0.03706311, 0.03287932, 0.03660752, 0.03450554],
[ 0.03988483, 0.03439958, 0.03887289, 0.03616712],
[ 0.03017368, 0.03424257, 0.03532363, 0.02857177],
[ 0.03855516, 0.03455752, 0.03746458, 0.0367456 ],
[ 0.03832089, 0.0379483 , 0.03854882, 0.03796184],
[ 0.04709962, 0.04396957, 0.04603146, 0.04606553],
[ 0.04521404, 0.04305011, 0.04424058, 0.04147213],
[ 0.03685116, 0.0401923 , 0.04021992, 0.0358785 ],
[ 0.0474408 , 0.04420525, 0.04624286, 0.04387761],
[ 0.04987269, 0.04401355, 0.04889452, 0.04502898],
[ 0.02905999, 0.03601027, 0.03544558, 0.02994194],
[ 0.02782917, 0.02909307, 0.0291289 , 0.02555896],
[ 0.03008661, 0.02874931, 0.02939694, 0.02995715]])
array([[ 0. , 0. , 0. , 0. ],
[ 0.00178892, 0.00581323, 0.00405187, 0.00398177],
[ 0.00257908, 0.00747371, 0.0046732 , 0.00338094],
[ 0.00203675, 0.00348608, 0.00232347, 0.00213218],
[-0.00692453, -0.00505548, -0.00605573, -0.00678109],
[-0.00181333, 0.00209729, 0.00050953, 0.0007319 ],
[ 0.00512789, 0.00635381, 0.00589577, 0.00632745],
[ 0.00450212, 0.00923766, 0.00952735, 0.00724806],
[-0.00202223, -0.0011549 , -0.00138809, -0.00214768],
[ 0.01070045, 0.01054684, 0.01024379, 0.00924638],
[ 0.0209006 , 0.02331426, 0.02084178, 0.01511553],
[ 0.0266184 , 0.02719005, 0.02689751, 0.02743577],
[ 0.0253474 , 0.02754012, 0.02552301, 0.02626014],
[ 0.03410133, 0.02956677, 0.03339144, 0.02960979],
[ 0.02980533, 0.03764984, 0.03734295, 0.03262044],
[ 0.03420609, 0.03322137, 0.03422491, 0.03291109],
[ 0.03702003, 0.03474214, 0.03648508, 0.03457011],
[ 0.02733564, 0.03458508, 0.03294398, 0.02698647],
[ 0.03569402, 0.03490013, 0.03508 , 0.0351477 ],
[ 0.0354604 , 0.03829204, 0.03616176, 0.03636206],
[ 0.04421495, 0.0443153 , 0.0436272 , 0.04445326],
[ 0.04233456, 0.04339553, 0.04184044, 0.03986694],
[ 0.03399472, 0.04053678, 0.03782902, 0.03428193],
[ 0.04455518, 0.04455105, 0.04383812, 0.04226871],
[ 0.04698038, 0.