about summary refs log tree commit diff
path: root/llvm/lib/Analysis/TFUtils.cpp
diff options
context:
space:
mode:
author Dimitry Andric <dim@FreeBSD.org> 2022-01-27 22:06:42 +0000
committer Dimitry Andric <dim@FreeBSD.org> 2022-01-27 22:06:42 +0000
commit 6f8fc217eaa12bf657be1c6468ed9938d10168b3 (patch)
tree a1fd89b864d9b93e2ad68fe1dcf7afee2e3c8d76 /llvm/lib/Analysis/TFUtils.cpp
parent 77fc4c146f0870ffb09c1afb823ccbe742c5e6ff (diff)
download src-6f8fc217eaa12bf657be1c6468ed9938d10168b3.tar.gz
src-6f8fc217eaa12bf657be1c6468ed9938d10168b3.zip
Vendor import of llvm-project main llvmorg-14-init-17616-g024a1fab5c35.
vendor/llvm-project/llvmorg-14-init-17616-g024a1fab5c35
Diffstat (limited to 'llvm/lib/Analysis/TFUtils.cpp')
-rw-r--r-- llvm/lib/Analysis/TFUtils.cpp 48
1 files changed, 39 insertions, 9 deletions
diff --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp
index 3d10479c4544..26bc63983b4e 100644
--- a/llvm/lib/Analysis/TFUtils.cpp
+++ b/llvm/lib/Analysis/TFUtils.cpp
@@ -14,6 +14,7 @@
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/Utils/TFUtils.h"
+#include "llvm/Support/Base64.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/JSON.h"
@@ -22,6 +23,7 @@
#include "llvm/Support/Path.h"
#include "llvm/Support/raw_ostream.h"
+#include "google/protobuf/struct.pb.h"
#include "google/protobuf/text_format.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
@@ -72,6 +74,14 @@ TFStatusPtr createTFStatus() {
TFSessionOptionsPtr createTFSessionOptions() {
return TFSessionOptionsPtr(TF_NewSessionOptions(), &TF_DeleteSessionOptions);
}
+
+void serialize(const Message &Msg, std::string *OutStr) {
+  // Text format when ProtobufTextMode is set, binary wire format otherwise.
+  if (!ProtobufTextMode)
+    *OutStr = Msg.SerializeAsString();
+  else
+    TextFormat::PrintToString(Msg, OutStr);
+}
} // namespace
namespace llvm {
@@ -307,19 +317,13 @@ public:
IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {}
// flush the logged info to a stream and clear the log contents.
- void flush(raw_ostream &OS) {
+ void flush(std::string *Str) { // NOTE(review): writes to *Str, not a stream.
size_t NrRecords = getNrRecords();
(void)NrRecords;
tensorflow::SequenceExample SE;
transferLog(SE);
assert(isSelfConsistent(SE, NrRecords));
- std::string OutStr;
- if (ProtobufTextMode)
- google::protobuf::TextFormat::PrintToString(SE, &OutStr);
- else
- OutStr = SE.SerializeAsString();
-
- OS << OutStr;
+ serialize(SE, Str); // Stores SE into *Str; text vs. binary per ProtobufTextMode.
}
char *addNewTensor(size_t FeatureID) {
@@ -567,5 +571,31 @@ char *Logger::addEntryAndGetFloatOrInt64Buffer(size_t FeatureID) {
return reinterpret_cast<char *>(LoggerData->addNewTensor(FeatureID));
}
-void Logger::flush(raw_ostream &OS) { LoggerData->flush(OS); }
+void Logger::flush(std::string *Str) { LoggerData->flush(Str); }
+
+void Logger::flush(raw_ostream &OS) {
+  std::string Serialized; // Filled by the std::string* overload above.
+  flush(&Serialized);
+  OS << Serialized;
+}
+
+void Logger::flushLogs(raw_ostream &OS,
+                       const StringMap<std::unique_ptr<Logger>> &Loggers) {
+  google::protobuf::Struct Msg;
+  for (const auto &NamedLogger : Loggers) {
+    // A logger with no records still gets an entry, with nothing flushed in.
+    const auto &TheLogger = NamedLogger.second;
+    std::string Unencoded;
+    if (TheLogger->LoggerData->getNrRecords() > 0)
+      TheLogger->flush(&Unencoded);
+
+    (*Msg.mutable_fields())[NamedLogger.first().str()]
+        .mutable_string_value()
+        ->append(ProtobufTextMode ? Unencoded : encodeBase64(Unencoded));
+  }
+
+  std::string OutStr;
+  serialize(Msg, &OutStr);
+  OS << OutStr;
+}
#endif // defined(LLVM_HAVE_TF_API)