Program Listing for File DataModifier.h
↰ Return to documentation for file (source/api_cc/include/DataModifier.h)
#pragma once
#include "DeepPot.h"
namespace deepmd {
/**
 * @brief Data modifier backed by a model loaded through init().
 *
 * Only the declaration lives in this header; the constructors, destructor,
 * init(), print_summary() and the templated members are defined elsewhere.
 *
 * NOTE(review): the class declares a destructor and stores raw pointers but
 * declares no copy/move operations. If the destructor releases either
 * pointer, copying an instance would alias it -- confirm against the
 * implementation before copying objects of this type.
 */
class DipoleChargeModifier {
 public:
  /// @brief Construct an object without loading a model.
  DipoleChargeModifier();

  /**
   * @brief Construct an object from a model.
   * @param[in] model      Model identifier (presumably a path to the model
   *                       file -- confirm against the implementation).
   * @param[in] gpu_rank   GPU rank to use; defaults to 0.
   * @param[in] name_scope Name scope forwarded to the implementation;
   *                       defaults to the empty string.
   */
  DipoleChargeModifier(const std::string& model,
                       const int& gpu_rank = 0,
                       const std::string& name_scope = "");

  ~DipoleChargeModifier();

  /**
   * @brief Initialise the object from a model.
   * @param[in] model      Model identifier (presumably a path to the model
   *                       file -- confirm against the implementation).
   * @param[in] gpu_rank   GPU rank to use; defaults to 0.
   * @param[in] name_scope Name scope forwarded to the implementation;
   *                       defaults to the empty string.
   */
  void init(const std::string& model,
            const int& gpu_rank = 0,
            const std::string& name_scope = "");

  /**
   * @brief Print a summary.
   * @param[in] pre String passed to the implementation (presumably a prefix
   *                for each printed line -- confirm).
   */
  void print_summary(const std::string& pre) const;

  /**
   * @brief Compute the correction terms produced by this modifier.
   *
   * The parameter meanings below are inferred from their names; confirm them
   * against the out-of-line definition.
   *
   * @tparam VALUETYPE Floating-point element type of the vectors.
   * @param[out] dfcorr_  Output vector (presumably the force correction).
   * @param[out] dvcorr_  Output vector (presumably the virial correction).
   * @param[in]  dcoord_  Input coordinates.
   * @param[in]  datype_  Input atom types.
   * @param[in]  dbox     Input box.
   * @param[in]  pairs    Pairs of integer indices.
   * @param[in]  delef_   Input vector (presumably the electric field).
   * @param[in]  nghost   Number of ghost atoms.
   * @param[in]  lmp_list Input neighbour list.
   */
  template <typename VALUETYPE>
  void compute(std::vector<VALUETYPE>& dfcorr_,
               std::vector<VALUETYPE>& dvcorr_,
               const std::vector<VALUETYPE>& dcoord_,
               const std::vector<int>& datype_,
               const std::vector<VALUETYPE>& dbox,
               const std::vector<std::pair<int, int>>& pairs,
               const std::vector<VALUETYPE>& delef_,
               const int nghost,
               const InputNlist& lmp_list);

  /**
   * @brief Return the stored cut-off value.
   * @pre The object is initialised; this is only checked by assert(), so the
   *      check disappears when NDEBUG is defined.
   */
  double cutoff() const {
    assert(inited);
    return rcut;
  }

  /**
   * @brief Return the stored number of types.
   * @pre The object is initialised (checked by assert() only).
   */
  int numb_types() const {
    assert(inited);
    return ntypes;
  }

  /**
   * @brief Return a copy of the stored selected-type list.
   * @pre The object is initialised (checked by assert() only).
   */
  std::vector<int> sel_types() const {
    assert(inited);
    return sel_type;
  }

 private:
  // Every scalar/pointer member carries a default initializer so that no
  // member is ever indeterminate, whichever constructor runs. A constructor's
  // own mem-initializer still takes precedence over these defaults.
  // Declaration order is kept as-is so the object layout does not change.

  // Session handle; ownership is not visible from this header.
  tensorflow::Session* session = nullptr;
  std::string name_scope, name_prefix;
  int num_intra_nthreads = 0;
  int num_inter_nthreads = 0;
  // Graph definition handle; ownership is not visible from this header.
  tensorflow::GraphDef* graph_def = nullptr;
  // Guards the accessors above; false until initialisation has completed.
  bool inited = false;
  double rcut = 0.0;
  int dtype = 0;
  double cell_size = 0.0;
  int ntypes = 0;
  std::string model_type;
  std::vector<int> sel_type;

  /// @brief Fetch a scalar of type VT identified by @p name.
  template <class VT>
  VT get_scalar(const std::string& name) const;

  /// @brief Fetch a vector of VT identified by @p name into @p vec.
  template <class VT>
  void get_vector(std::vector<VT>& vec, const std::string& name) const;

  /**
   * @brief Run the model on prepared input tensors.
   * @param[out] dforce        Output vector (presumably forces -- confirm).
   * @param[out] dvirial       Output vector (presumably virial -- confirm).
   * @param[in]  session       Session used for the run.
   * @param[in]  input_tensors Named input tensors.
   * @param[in]  atommap       Atom map passed to the implementation.
   * @param[in]  nghost        Number of ghost atoms.
   */
  template <typename MODELTYPE, typename VALUETYPE>
  void run_model(std::vector<VALUETYPE>& dforce,
                 std::vector<VALUETYPE>& dvirial,
                 tensorflow::Session* session,
                 const std::vector<std::pair<std::string, tensorflow::Tensor>>&
                     input_tensors,
                 const AtomMap& atommap,
                 const int nghost);
};
} // namespace deepmd