Hello, I noticed that when I run L-BFGS-B after training with Adam, it stops after only about twenty iterations, finishing within seconds. I'd like to know why this is happening.
import tensorflow_probability as tfp
from deepxde.backend import tf
import deepxde as dde
import numpy as np
import pandas as pd
import time
import os
import random
import math
# soil geometry
soil = [0.0, -1.0, 0.0, 1.0]  # [zmin, zmax, Tinitial, Tfinal]
# WRC: theta(psi)
def theta_function(h, thetar, thetas, alpha, n, m):
    term2 = 1 + tf.pow(-alpha * h, n)
    term3 = tf.pow(term2, -m)
    result = thetar + (thetas - thetar) * term3
    thetaSvg_broadcast = tf.ones_like(result) * thetaSvg
    result = tf.where(h > 0, thetaSvg_broadcast, result)
    return result
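# (For reference, this appears to implement the van Genuchten (1980) retention curve:
#  theta(h) = theta_r + (theta_s - theta_r) * [1 + (-alpha*h)^n]^(-m) for h < 0,
#  with theta = theta_s for h >= 0.)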
# HCF: K(psi)
def K_function(h, thetar, thetas, alpha, n, m, Ks):
    theta_h = theta_function(h, thetar, thetas, alpha, n, m)
    term1 = tf.pow((theta_h - thetar) / (thetas - thetar), 0.5)
    term2 = 1 - tf.pow(1 - tf.pow((theta_h - thetar) / (thetas - thetar), 1/m), m)
    result = Ks * term1 * tf.pow(term2, 2)
    # fix: handle the h > 0 case with the same broadcast shape fix as in theta_function
    ksvg_broadcast = tf.ones_like(result) * ksvg
    result = tf.where(h > 0, ksvg_broadcast, result)
    return result
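# (Likewise, K_function appears to follow the Mualem-van Genuchten conductivity model:
#  K(h) = K_s * Se^(1/2) * [1 - (1 - Se^(1/m))^m]^2, with effective saturation
#  Se = (theta - theta_r) / (theta_s - theta_r), and K = K_s for h >= 0.)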
# soil parameters: loam, [m and day] units
nvg = tf.constant(1.56, dtype=tf.float32)
mvg = 1 - 1/nvg
ksvg = tf.constant(0.2496, dtype=tf.float32)
alphavg = tf.constant(3.6, dtype=tf.float32)
thetaRvg = tf.constant(0.078, dtype=tf.float32)
thetaSvg = tf.constant(0.43, dtype=tf.float32)
############################ IBC (initial and boundary conditions) ############################
# water:
psi_ic = -1.        # initial condition for pressure head
psi_surface = -0.2  # soil surface pressure head
# solutes:
c_initial = 0.      # initial condition for all solutes
c1_inlet = 1.       # soil surface concentration for solute 1: NH4+
c2_inlet = 0.01     # soil surface concentration for solute 2: NO2-
c3_inlet = 0.2      # soil surface concentration for solute 3: NO3-
# PINN structure
num_layers = 5
num_neurons = 50
number_random = 111
layers = np.concatenate([[2], num_neurons*np.ones(num_layers), [1]]).astype(int).tolist()  # [2, 50, 50, 50, 50, 50, 1]
# iterations for water and solute
itwater = 1000
# weights for the loss functions
constant_ic, constant_up, constant_dw, constant_res = 10, 1, 1, 1
# function for generating collocation points
def get_collocations(soil, n):
    z = np.random.uniform(soil[0], soil[1], n).reshape(-1, 1)
    t = np.random.uniform(soil[2], soil[3], n).reshape(-1, 1)
    return t, z
# size for each set of collocation points
n_res, n_ic, n_up, n_dw = 10000, 100, 1000, 100
np.random.seed(0)
# residual points
t_res, z_res = get_collocations(soil, n_res)
# initial collocation points
t_ic, z_ic = get_collocations(list(np.append(soil[0:3], 0)), n_ic)
# collocation points for the soil surface boundary
t_up, z_up = get_collocations([soil[0], soil[0], soil[2], soil[3]], n_up)
# collocation points for the soil bottom boundary
t_dw, z_dw = get_collocations([soil[1], soil[1], soil[2], soil[3]], n_dw)
############################# water solver ###############################################
class water:
    def __init__(self, layers, LAA):
        self.LAA = LAA
        self.weights_psi, self.biases_psi, self.A_psi = self.initialize_NN(layers)
        # tf placeholders: empty variables
        [self.t_res_tf, self.z_res_tf, self.t_ic_tf, self.z_ic_tf,
         self.t_up_tf, self.z_up_tf,
         self.z_dw_tf, self.t_dw_tf] = [tf.compat.v1.placeholder(tf.float32, shape=[None, 1]) for _ in range(8)]
        # tf session
        self.sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=True))
        # predictions from the PINN
        self.psi_pred, self.residual_pred = self.net_res(self.t_res_tf, self.z_res_tf)
        self.psi_ic_pred = self.net_ic(self.t_ic_tf, self.z_ic_tf)
        self.psi_up_pred = self.net_ic(self.t_up_tf, self.z_up_tf)
        self.q_dw_pred = self.net_q_dw(self.t_dw_tf, self.z_dw_tf)
        self.psi_up = tf.fill(tf.shape(self.psi_up_pred), psi_surface)   # upper BC
        self.psi_ic_exact = tf.fill(tf.shape(self.psi_ic_pred), psi_ic)  # IC
        # loss function
        self.loss_res = tf.reduce_mean(tf.square(self.residual_pred))
        self.loss_ic = tf.reduce_mean(tf.square(self.psi_ic_pred - self.psi_ic_exact))
        self.loss_up = tf.reduce_mean(tf.square(self.psi_up_pred - self.psi_up))
        self.loss_dw = tf.reduce_mean(tf.square(self.q_dw_pred))
        self.loss = constant_res * self.loss_res + constant_ic * self.loss_ic \
                    + constant_up * self.loss_up \
                    + constant_dw * self.loss_dw
        # L-BFGS-B method
        # (see the standalone SciPy sketch at the end of this post for how these options
        #  relate to the scipy.optimize stopping criteria)
        self.optimizer = dde.optimizers.tensorflow_compat_v1.scipy_optimizer.ScipyOptimizerInterface(
            self.loss,
            method='L-BFGS-B',
            options={'maxiter': 50000,
                     'maxfun': 50000,
                     'maxcor': 50,
                     'maxls': 50,
                     'ftol': 1e-12,
                     'gtol': 1e-10})
        # define the Adam optimizer with a learning-rate schedule
        self.global_step = tf.Variable(0, trainable=False)
        self.starter_learning_rate = 1e-3
        self.learning_rate = tf.compat.v1.train.exponential_decay(self.starter_learning_rate, self.global_step,
                                                                  1000, 0.90, staircase=False)
        self.train_op_Adam = tf.compat.v1.train.AdamOptimizer(self.learning_rate).minimize(self.loss, global_step=self.global_step)
        init = tf.compat.v1.global_variables_initializer()
        self.sess.run(init)
        # tf.saver
        self.saver = tf.compat.v1.train.Saver()
        # total loss history
        self.loss_total = []
        self.lbfgs_iter = 0
        self.lbfgs_nit = 0
    def initialize_NN(self, layers):
        num_layers = len(layers)
        weights = []
        biases = []
        A = []
        for l in range(0, num_layers - 1):
            in_dim = layers[l]
            out_dim = layers[l + 1]
            xavier_stddev = np.sqrt(2 / (in_dim + out_dim))
            W = tf.Variable(tf.random.truncated_normal([in_dim, out_dim], stddev=xavier_stddev), dtype=tf.float32, trainable=True)
            b = tf.Variable(np.zeros([1, out_dim]), dtype=tf.float32, trainable=True)
            weights.append(W)
            biases.append(b)
            a = tf.Variable(0.05, dtype=tf.float32, trainable=True)
            A.append(a)
        return weights, biases, A
    def net_psi(self, X, weights, biases, A):
        num_layers = len(weights) + 1
        H = X
        for l in range(0, num_layers - 1):
            W = weights[l]
            b = biases[l]
            H = tf.add(tf.matmul(H, W), b)
            # activation (locally adaptive if LAA is enabled)
            if l < num_layers - 2:
                if self.LAA:
                    H = tf.tanh(20 * A[l] * H)
                else:
                    H = tf.tanh(H)
        H = tf.clip_by_value(H, -40.0, 40.0)
        return -tf.exp(H)
    def net_res(self, t, z):
        X = tf.concat([t, z], 1)
        psi = self.net_psi(X, self.weights_psi, self.biases_psi, self.A_psi)
        theta = theta_function(psi, thetaRvg, thetaSvg, alphavg, nvg, mvg)
        K = K_function(psi, thetaRvg, thetaSvg, alphavg, nvg, mvg, ksvg)
        theta_t = tf.gradients(theta, t)[0]
        psi_z = tf.gradients(psi, z)[0]
        q_exact = -K * (psi_z + 1)
        q_z = tf.gradients(q_exact, z)[0]
        # residual loss
        res_richards = theta_t + q_z
        return psi, res_richards
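    # (For reference, res_richards appears to be the residual of the 1-D mixed-form
    #  Richards equation, theta_t + q_z = 0 with q = -K * (psi_z + 1), i.e.
    #  d(theta)/dt = d/dz [ K * (d(psi)/dz + 1) ].)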
    def net_ic(self, t, z):
        X = tf.concat([t, z], 1)
        psi = self.net_psi(X, self.weights_psi, self.biases_psi, self.A_psi)
        return psi
    def net_q_dw(self, t, z):
        X = tf.concat([t, z], 1)
        psi = self.net_psi(X, self.weights_psi, self.biases_psi, self.A_psi)
        psi_z = tf.gradients(psi, z)[0]
        return psi_z
    def net_water(self, t, z, w, b, a):
        X = tf.concat([t, z], 1)
        psi = self.net_psi(X, w, b, a)
        theta = theta_function(psi, thetaRvg, thetaSvg, alphavg, nvg, mvg)
        K = K_function(psi, thetaRvg, thetaSvg, alphavg, nvg, mvg, ksvg)
        psi_z = tf.gradients(psi, z)[0]
        q = -K * (psi_z + 1)
        return theta, q
    def train(self, N_iter, batch=False, batch_size=500):
        tf_dict = {self.t_res_tf: t_res,
                   self.z_res_tf: z_res,
                   self.t_ic_tf: t_ic,
                   self.z_ic_tf: z_ic,
                   self.t_up_tf: t_up,
                   self.z_up_tf: z_up,
                   self.t_dw_tf: t_dw,
                   self.z_dw_tf: z_dw}
        start_time = time.time()
        print("Adam epochs:")
        for it in range(N_iter):
            self.sess.run(self.train_op_Adam, tf_dict)
            if it % 10 == 0:
                elapsed = time.time() - start_time
                loss_value, loss_res_value, loss_ic_value, loss_up_value, loss_dw_value = self.sess.run(
                    [self.loss, self.loss_res, self.loss_ic, self.loss_up, self.loss_dw], tf_dict)
                print("-" * 120)
                print('Epoch: %d, Loss: %.3e, Loss_r: %.3e, Loss_ic: %.3e, Loss_up: %.3e, Loss_dw: %.3e, Time: %.2f' %
                      (it, loss_value, loss_res_value, loss_ic_value, loss_up_value, loss_dw_value, elapsed))
                self.loss_total.append(loss_value)
                start_time = time.time()
        # L-BFGS-B
        print("L-BFGS-B epochs:")
        tf_dict_check = tf_dict
        loss_val, res_val = self.sess.run([self.loss, self.loss_res], tf_dict_check)
        print("Pre-LBFGS check - loss: %.6e, loss_res: %.6e" % (loss_val, res_val))
        self.optimizer.minimize(self.sess,
                                feed_dict=tf_dict,
                                fetches=[self.loss],
                                loss_callback=self.callback)
        # the final loss value is computed
        loss_value = self.sess.run(self.loss, tf_dict)
    def callback(self, loss):
        self.lbfgs_iter += 1
        print(f'L-BFGS-B Iteration: {self.lbfgs_iter}, Loss: {loss}')
    def predict(self, t_star, z_star):
        tf_dict = {self.t_res_tf: t_star,
                   self.z_res_tf: z_star, self.t_up_tf: t_star, self.z_up_tf: z_star}
        psi = self.sess.run(self.psi_pred, tf_dict)
        weights_psi = self.sess.run(self.weights_psi)
        biases_psi = self.sess.run(self.biases_psi)
        a_psi = self.sess.run(self.A_psi)
        theta = self.sess.run(theta_function(psi, thetaRvg, thetaSvg, alphavg, nvg, mvg))
        total_loss = self.loss_total
        return psi, weights_psi, biases_psi, a_psi, theta, total_loss
Training output:
Epoch: 960, Loss: 6.216e-04, Loss_r: 1.832e-04, Loss_ic: 1.635e-05, Loss_up: 2.730e-04, Loss_dw: 1.936e-06, Time: 0.18
Epoch: 970, Loss: 4.229e-03, Loss_r: 1.947e-04, Loss_ic: 3.702e-04, Loss_up: 2.975e-04, Loss_dw: 3.507e-05, Time: 0.17
Epoch: 980, Loss: 1.203e-03, Loss_r: 1.796e-04, Loss_ic: 7.367e-05, Loss_up: 2.802e-04, Loss_dw: 6.227e-06, Time: 0.17
Epoch: 990, Loss: 6.621e-04, Loss_r: 1.815e-04, Loss_ic: 2.327e-05, Loss_up: 2.452e-04, Loss_dw: 2.720e-06, Time: 0.17
L-BFGS-B epochs:
Pre-LBFGS check - loss: 6.299498e-04, loss_res: 1.734043e-04
L-BFGS-B Iteration: 1, Loss: 0.0006299503147602081
L-BFGS-B Iteration: 2, Loss: 200666.5625
L-BFGS-B Iteration: 3, Loss: 3.4199745655059814
L-BFGS-B Iteration: 4, Loss: 0.0005860932869836688
L-BFGS-B Iteration: 5, Loss: 0.0005019070813432336
L-BFGS-B Iteration: 6, Loss: 0.0005020918906666338
L-BFGS-B Iteration: 7, Loss: 0.0005017896764911711
L-BFGS-B Iteration: 8, Loss: 0.0005020083044655621
L-BFGS-B Iteration: 9, Loss: 0.0005017738440074027
L-BFGS-B Iteration: 10, Loss: 0.0005020084208808839
L-BFGS-B Iteration: 11, Loss: 0.0005018889205530286
L-BFGS-B Iteration: 12, Loss: 0.000501766917295754
L-BFGS-B Iteration: 13, Loss: 0.000501766917295754
L-BFGS-B Iteration: 14, Loss: 0.0005016588256694376
L-BFGS-B Iteration: 15, Loss: 0.0005016599898226559
L-BFGS-B Iteration: 16, Loss: 0.0005016588256694376
L-BFGS-B Iteration: 17, Loss: 0.0005021147662773728
L-BFGS-B Iteration: 18, Loss: 0.0005017010844312608
L-BFGS-B Iteration: 19, Loss: 0.0005021606921218336
L-BFGS-B Iteration: 20, Loss: 0.0005020289099775255
L-BFGS-B Iteration: 21, Loss: 0.0005016594077460468
L-BFGS-B Iteration: 22, Loss: 0.0005016588256694376
L-BFGS-B Iteration: 23, Loss: 0.0005016588256694376
L-BFGS-B Iteration: 24, Loss: 0.0005016588256694376
L-BFGS-B Iteration: 25, Loss: 0.0005016588256694376
L-BFGS-B Iteration: 26, Loss: 0.0005016588256694376
tf.compat.v1.set_random_seed(0)  # fix TensorFlow's random generator
Richards = water(layers, LAA=True)
# train the water solver
Richards.train(itwater)
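For what it's worth, the loss in the L-BFGS-B log above is essentially flat after the first few iterations. As far as I can tell, the ScipyOptimizerInterface used here hands its options to scipy.optimize.minimize, so the run ends as soon as any of the L-BFGS-B stopping tests fires (relative reduction of the loss below ftol, projected-gradient norm below gtol, maxiter/maxfun, or an abnormal line-search termination). Below is a minimal standalone sketch (plain SciPy with a placeholder quadratic objective, not the PINN loss) of how one might check which criterion ends a run with the same options:

import numpy as np
from scipy.optimize import minimize

def toy_loss(x):
    # placeholder objective; in the PINN case the loss and its gradient come from
    # the float32 TensorFlow graph instead
    return float(np.sum(np.float32(x) ** 2))

res = minimize(toy_loss, x0=np.full(100, 0.1), method="L-BFGS-B",
               options={"maxiter": 50000, "maxfun": 50000, "maxcor": 50,
                        "maxls": 50, "ftol": 1e-12, "gtol": 1e-10})
print(res.nit, res.status, res.message)  # res.message names the stopping criterion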