diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2022-01-27 22:06:42 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2022-01-27 22:06:42 +0000 |
commit | 6f8fc217eaa12bf657be1c6468ed9938d10168b3 (patch) | |
tree | a1fd89b864d9b93e2ad68fe1dcf7afee2e3c8d76 /llvm/lib/Analysis/TFUtils.cpp | |
parent | 77fc4c146f0870ffb09c1afb823ccbe742c5e6ff (diff) | |
download | src-6f8fc217eaa12bf657be1c6468ed9938d10168b3.tar.gz src-6f8fc217eaa12bf657be1c6468ed9938d10168b3.zip |
Vendor import of llvm-project main llvmorg-14-init-17616-g024a1fab5c35. (tag: vendor/llvm-project/llvmorg-14-init-17616-g024a1fab5c35)
Diffstat (limited to 'llvm/lib/Analysis/TFUtils.cpp')
-rw-r--r-- | llvm/lib/Analysis/TFUtils.cpp | 48 |
1 files changed, 39 insertions, 9 deletions
diff --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp index 3d10479c4544..26bc63983b4e 100644 --- a/llvm/lib/Analysis/TFUtils.cpp +++ b/llvm/lib/Analysis/TFUtils.cpp @@ -14,6 +14,7 @@ #include "llvm/ADT/Twine.h" #include "llvm/Analysis/Utils/TFUtils.h" +#include "llvm/Support/Base64.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/JSON.h" @@ -22,6 +23,7 @@ #include "llvm/Support/Path.h" #include "llvm/Support/raw_ostream.h" +#include "google/protobuf/struct.pb.h" #include "google/protobuf/text_format.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_experimental.h" @@ -72,6 +74,14 @@ TFStatusPtr createTFStatus() { TFSessionOptionsPtr createTFSessionOptions() { return TFSessionOptionsPtr(TF_NewSessionOptions(), &TF_DeleteSessionOptions); } + +void serialize(const Message &SE, std::string *OutStr) { + if (ProtobufTextMode) { + TextFormat::PrintToString(SE, OutStr); + } else { + *OutStr = SE.SerializeAsString(); + } +} } // namespace namespace llvm { @@ -307,19 +317,13 @@ public: IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {} // flush the logged info to a stream and clear the log contents. 
- void flush(raw_ostream &OS) { + void flush(std::string *Str) { size_t NrRecords = getNrRecords(); (void)NrRecords; tensorflow::SequenceExample SE; transferLog(SE); assert(isSelfConsistent(SE, NrRecords)); - std::string OutStr; - if (ProtobufTextMode) - google::protobuf::TextFormat::PrintToString(SE, &OutStr); - else - OutStr = SE.SerializeAsString(); - - OS << OutStr; + serialize(SE, Str); } char *addNewTensor(size_t FeatureID) { @@ -567,5 +571,31 @@ char *Logger::addEntryAndGetFloatOrInt64Buffer(size_t FeatureID) { return reinterpret_cast<char *>(LoggerData->addNewTensor(FeatureID)); } -void Logger::flush(raw_ostream &OS) { LoggerData->flush(OS); } +void Logger::flush(std::string *Str) { LoggerData->flush(Str); } + +void Logger::flush(raw_ostream &OS) { + std::string Buff; + LoggerData->flush(&Buff); + OS << Buff; +} + +void Logger::flushLogs(raw_ostream &OS, + const StringMap<std::unique_ptr<Logger>> &Loggers) { + google::protobuf::Struct Msg; + for (const auto &NamedLogger : Loggers) { + tensorflow::SequenceExample SE; + const auto &Logger = NamedLogger.second; + std::string Unencoded; + if (Logger->LoggerData->getNrRecords() > 0) + Logger->flush(&Unencoded); + + (*Msg.mutable_fields())[NamedLogger.first().str()] + .mutable_string_value() + ->append(ProtobufTextMode ? Unencoded : encodeBase64(Unencoded)); + } + + std::string OutStr; + serialize(Msg, &OutStr); + OS << OutStr; +} #endif // defined(LLVM_HAVE_TF_API) |