feat(tf): add support for stat_file parameter#4926
Conversation
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
input_v2_compat.json
Outdated
deepmd/tf/entrypoints/train.py
Outdated
|
|
||
| if not Path(stat_file_raw).exists(): | ||
| if stat_file_raw.endswith((".h5", ".hdf5")): | ||
| import h5py |
There was a problem hiding this comment.
@copilot The import should be moved outside the method.
There was a problem hiding this comment.
Moved the h5py import to the top-level imports with proper error handling. Fixed in 8e731c8.
There was a problem hiding this comment.
@copilot h5py is always installed, so no need to use try...except.
There was a problem hiding this comment.
Removed the unnecessary try/except block around h5py import since h5py is always available. The import is now a direct import h5py and the None check has been removed. Fixed in 2cb3163.
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
deepmd/tf/entrypoints/train.py
Outdated
| from pathlib import ( | ||
| Path, | ||
| ) | ||
|
|
||
| from deepmd.utils.path import ( | ||
| DPPath, | ||
| ) |
deepmd/tf/model/ener.py
Outdated
| from deepmd.tf.utils.stat import ( | ||
| compute_output_stats, | ||
| ) |
| except Exception as e: | ||
| # Expected to fail due to missing training data | ||
| # But the stat_file parameter should have been processed | ||
| print(f"Expected training failure: {e}") |
source/tests/tf/test_stat_file.py
Outdated
| self.assertTrue( | ||
| stat_path.is_dir(), "Stat file path should be a directory" | ||
| ) | ||
| print(f"Stat file directory was created: {stat_file_path}") |
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #4926 +/- ##
==========================================
+ Coverage 84.29% 84.47% +0.17%
==========================================
Files 703 705 +2
Lines 68728 69769 +1041
Branches 3573 3573
==========================================
+ Hits 57935 58935 +1000
- Misses 9653 9695 +42
+ Partials 1140 1139 -1 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
…tency test Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Added comprehensive cross-backend consistency test in
The test ensures identical directory structure, file formats, and numerical values between backends (values within 1e-4 tolerance). Fixed in 1e4deb2. |
There was a problem hiding this comment.
@copilot I don't think these tests need to test different seeds, type maps, and batches. It does not make sense.
There was a problem hiding this comment.
Removed the unnecessary test variations for different seeds, type maps, and batch sizes. Now only testing basic stat file consistency between TensorFlow and PyTorch backends. Fixed in 5864cee.
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Updated PR title from "feat: ensure TensorFlow and PyTorch stat file consistency with comprehensive testing" to "feat: add TensorFlow support for stat_file parameter" to better reflect the main feature being added. |
|
@copilot CI encounters the following error: |
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
... Fixed the CI test failure. The issue was that the test was using |
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
model.ckpt.data-00000-of-00001
Outdated
There was a problem hiding this comment.
Removed temporary checkpoint files (model.ckpt*, checkpoint, lcurve.out) from the repository.
There was a problem hiding this comment.
@copilot The descriptor stat data is not implemented in TF, so I am wondering how TF and PT match?
There was a problem hiding this comment.
The stat_file feature doesn't save descriptor statistics - it saves energy statistics (bias_atom_energy, std_atom_energy) from the fitting network output. Both TF and PyTorch backends can compute these energy statistics consistently because:
- Both use the same training data and energy computation logic
- The statistics come from the fitting network output, not the descriptor
- Both backends save files like
bias_atom_energyandstd_atom_energyin type_map subdirectories
The descriptor compute_input_stats methods in both backends are for internal environment matrix normalization, not for the stat_file feature. The consistency test validates that both backends produce identical energy statistics.
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
stat_fileparameter for TensorFlow backendstat_file_pathparameter throughout the TensorFlow training flowdeepmd/tf/utils/stat.pywith save/load functionality compatible with PyTorch formatdata_stat()methods to support stat file operationsBackend Consistency
The implementation ensures complete consistency between TensorFlow and PyTorch backends:
stat_file/O H/)bias_atom_energy,std_atom_energy) and array shapesTesting
Added cross-backend consistency test to validate that TensorFlow and PyTorch produce identical stat file behavior, ensuring backends create the same directory structures, file formats, and numerical values within tolerance.
Usage
The
stat_fileparameter can now be used in TensorFlow training configurations:{ "training": { "stat_file": "/path/to/stat_files", "training_data": { ... }, ... } }This works seamlessly with the CLI:
Compatibility
Fixes #4017.
💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more Copilot coding agent tips in the docs.