Program Listing for File DataModifier.h

Return to documentation for file (source/api_cc/include/DataModifier.h)

#pragma once

#include "DeepPot.h"

namespace deepmd {
/// Wrapper around a TensorFlow session/graph that exposes a `compute`
/// entry point plus a few read-only model properties (cutoff, number of
/// types, selected types).
///
/// NOTE(review): the class name suggests it applies a dipole-charge
/// correction; only the declarations are visible here, so the precise
/// physics lives in the implementation file — confirm there.
class DipoleChargeModifier {
 public:
  /// Default constructor. Takes no model; NOTE(review): presumably leaves the
  /// object un-initialised until init() is called — confirm in the .cc file.
  DipoleChargeModifier();

  /// Constructs from a model.
  /// @param model      Model identifier (NOTE(review): presumably a file
  ///                   path — confirm).
  /// @param gpu_rank   Defaults to 0.
  /// @param name_scope Defaults to the empty string.
  DipoleChargeModifier(const std::string& model,
                       const int& gpu_rank = 0,
                       const std::string& name_scope = "");

  ~DipoleChargeModifier();

  /// Same parameters and defaults as the three-argument constructor.
  /// @param model      Model identifier (NOTE(review): presumably a file
  ///                   path — confirm).
  /// @param gpu_rank   Defaults to 0.
  /// @param name_scope Defaults to the empty string.
  void init(const std::string& model,
            const int& gpu_rank = 0,
            const std::string& name_scope = "");

  /// Prints a summary; does not modify the object (const).
  /// @param pre NOTE(review): presumably a prefix put in front of each
  ///            printed line — confirm.
  void print_summary(const std::string& pre) const;

  /// Runs the modifier for one configuration.
  ///
  /// `dfcorr_` and `dvcorr_` are the only non-const references and are
  /// therefore the outputs. NOTE(review): the names suggest force and virial
  /// corrections, coordinates, atom types, box, index pairs and an external
  /// field — confirm meanings, units and expected sizes in the
  /// implementation.
  ///
  /// @tparam VALUETYPE Element type of the floating-point style vectors.
  /// @param[out] dfcorr_  Output vector.
  /// @param[out] dvcorr_  Output vector.
  /// @param[in]  dcoord_  Input vector of VALUETYPE.
  /// @param[in]  datype_  Input vector of int.
  /// @param[in]  dbox     Input vector of VALUETYPE.
  /// @param[in]  pairs    Input vector of (int, int) pairs.
  /// @param[in]  delef_   Input vector of VALUETYPE.
  /// @param[in]  nghost   NOTE(review): presumably the ghost-atom count.
  /// @param[in]  lmp_list Neighbour-list input object.
  template <typename VALUETYPE>
  void compute(std::vector<VALUETYPE>& dfcorr_,
               std::vector<VALUETYPE>& dvcorr_,
               const std::vector<VALUETYPE>& dcoord_,
               const std::vector<int>& datype_,
               const std::vector<VALUETYPE>& dbox,
               const std::vector<std::pair<int, int>>& pairs,
               const std::vector<VALUETYPE>& delef_,
               const int nghost,
               const InputNlist& lmp_list);

  /// Returns the stored `rcut` value.
  /// Asserts that the object has been initialised (the check is only active
  /// when assertions are enabled, i.e. NDEBUG is not defined).
  double cutoff() const {
    assert(inited);
    return rcut;
  }

  /// Returns the stored `ntypes` value.
  /// Asserts that the object has been initialised.
  int numb_types() const {
    assert(inited);
    return ntypes;
  }

  /// Returns a copy of the stored `sel_type` vector.
  /// Asserts that the object has been initialised.
  std::vector<int> sel_types() const {
    assert(inited);
    return sel_type;
  }

 private:
  // Data members are kept in their original order: declaration order fixes
  // both object layout and constructor initialisation order.

  // NOTE(review): raw pointer; who owns/frees it is decided in the .cc file.
  tensorflow::Session* session;
  std::string name_scope;
  std::string name_prefix;
  int num_intra_nthreads;
  int num_inter_nthreads;
  // NOTE(review): raw pointer; who owns/frees it is decided in the .cc file.
  tensorflow::GraphDef* graph_def;
  // Checked by the assert in every inline accessor above.
  bool inited;
  // Value returned by cutoff().
  double rcut;
  int dtype;
  double cell_size;
  // Value returned by numb_types().
  int ntypes;
  std::string model_type;
  // Value returned (by copy) from sel_types().
  std::vector<int> sel_type;

  /// Fetches a single value of type VT identified by `name`.
  /// NOTE(review): presumably read from the loaded graph — confirm.
  template <class VT>
  VT get_scalar(const std::string& name) const;

  /// Fills `vec` with values of type VT identified by `name`.
  /// NOTE(review): presumably read from the loaded graph — confirm.
  template <class VT>
  void get_vector(std::vector<VT>& vec, const std::string& name) const;

  /// Internal helper taking a session and named input tensors; `dforce` and
  /// `dvirial` are non-const references and therefore outputs.
  /// Note that the `session` parameter shadows the data member of the same
  /// name inside this function.
  template <typename MODELTYPE, typename VALUETYPE>
  void run_model(std::vector<VALUETYPE>& dforce,
                 std::vector<VALUETYPE>& dvirial,
                 tensorflow::Session* session,
                 const std::vector<std::pair<std::string, tensorflow::Tensor>>&
                     input_tensors,
                 const AtomMap& atommap,
                 const int nghost);
};
}  // namespace deepmd