/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <fcntl.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "perfetto/base/logging.h"
#include "perfetto/ext/base/file_utils.h"
#include "perfetto/ext/base/getopt.h"
#include "perfetto/ext/base/scoped_file.h"
#include "perfetto/ext/base/string_utils.h"
#include "perfetto/ext/base/version.h"
#include "protos/perfetto/config/trace_config.gen.h"
#include "src/proto_utils/txt_to_pb.h"
#include "src/protozero/filtering/filter_util.h"
#include "src/protozero/filtering/message_filter.h"
#include "src/protozero/filtering/string_filter.h"

namespace perfetto {
namespace proto_filter {
namespace {

// Help text printed by Main() for --help and whenever the command line is
// rejected. Option names listed here must match the entries of the
// |long_options| table in Main(): the schema option is registered as
// "schema_in" (underscore), so that is the spelling documented below.
const char kUsage[] =
    R"(Usage: proto_filter [options]

-s --schema_in:      Path to the root .proto file. Required for most operations.
                     Filtering options (passthrough, filter_string, semantic_type) are read
                     from [(perfetto.protos.proto_filter)] annotations on each field.
-I --proto_path:     Extra include directory for proto includes. If omitted assumed CWD.
-r --root_message:   Fully qualified name for the root proto message (e.g. perfetto.protos.Trace)
                     If omitted the first message defined in the schema will be used.
-i --msg_in:         Path of a binary-encoded proto message which will be filtered.
-o --msg_out:        Path of the binary-encoded filtered proto message written in output.
-c --config_in:      Path of a TraceConfig textproto (note: only trace_filter field is considered).
-f --filter_in:      Path of a filter bytecode file previously generated by this tool.
-F --filter_out:     Path of the filter bytecode file generated from the --schema_in definition.
-T --filter_oct_out: Like --filter_out, but emits a octal-escaped C string suitable for .pbtx.
   --overlay_v54_out: Path of the v54 overlay bytecode file generated from the --schema-in definition.
   --overlay_v54_oct_out: Like --overlay_v54_out, but emits a octal-escaped C string suitable for .pbtx.
-d --dedupe:         Minimize filter size by deduping leaf messages with same field ids.
   --min-bytecode-parser: Minimum bytecode parser version to target (v1, v2, v54).
                     Default: v2.

Example usage:

# Convert a .proto schema file into a diff-friendly list of messages/fields>

  proto_filter -r perfetto.protos.Trace -s protos/perfetto/trace/trace.proto

# Generate the filter bytecode from a .proto schema

  proto_filter -r perfetto.protos.Trace -s protos/perfetto/trace/trace.proto \
               -F /tmp/bytecode [--dedupe]

# List the used/filtered fields from a trace file

  proto_filter -r perfetto.protos.Trace -s protos/perfetto/trace/trace.proto \
               -i test/data/example_android_trace_30s.pb -f /tmp/bytecode

# Filter a trace using a filter bytecode

  proto_filter -i test/data/example_android_trace_30s.pb -f /tmp/bytecode \
               -o /tmp/filtered_trace

# Filter a trace using a TraceConfig textproto

  proto_filter -i test/data/example_android_trace_30s.pb \
               -c /tmp/config.textproto \
               -o /tmp/filtered_trace

# Show which fields are allowed by a filter bytecode

  proto_filter -r perfetto.protos.Trace -s protos/perfetto/trace/trace.proto \
               -f /tmp/bytecode
)";

// Shorthands for the generated TraceConfig.TraceFilter message and its nested
// StringFilterRule message, both used by the helpers and Main() below.
using TraceFilter = protos::gen::TraceConfig::TraceFilter;
using StringFilterRule = TraceFilter::StringFilterRule;

// Translates a TraceConfig string-filter policy into the corresponding
// protozero::StringFilter policy. The result is std::nullopt for
// SFP_UNSPECIFIED and for any value that no case below handles.
std::optional<protozero::StringFilter::Policy> ConvertPolicy(
    TraceFilter::StringFilterPolicy policy) {
  using Policy = protozero::StringFilter::Policy;

  std::optional<Policy> converted;  // Stays empty unless a case assigns it.
  switch (policy) {
    case TraceFilter::SFP_UNSPECIFIED:
      break;
    case TraceFilter::SFP_MATCH_REDACT_GROUPS:
      converted = Policy::kMatchRedactGroups;
      break;
    case TraceFilter::SFP_ATRACE_MATCH_REDACT_GROUPS:
      converted = Policy::kAtraceMatchRedactGroups;
      break;
    case TraceFilter::SFP_MATCH_BREAK:
      converted = Policy::kMatchBreak;
      break;
    case TraceFilter::SFP_ATRACE_MATCH_BREAK:
      converted = Policy::kAtraceMatchBreak;
      break;
    case TraceFilter::SFP_ATRACE_REPEATED_SEARCH_REDACT_GROUPS:
      converted = Policy::kAtraceRepeatedSearchRedactGroups;
      break;
  }
  return converted;
}

// Builds the semantic-type mask for |rule|: every entry of
// rule.semantic_type() whose numeric value is below SemanticTypeMask::kLimit
// is set in the returned mask; entries at or above the limit are skipped.
protozero::StringFilter::SemanticTypeMask ConvertSemanticTypes(
    const StringFilterRule& rule) {
  using SemanticTypeMask = protozero::StringFilter::SemanticTypeMask;

  SemanticTypeMask result;
  for (const auto& entry : rule.semantic_type()) {
    const auto bit_index = static_cast<uint32_t>(entry);
    if (bit_index >= SemanticTypeMask::kLimit)
      continue;
    result.Set(bit_index);
  }
  return result;
}

// Writes |data| verbatim to the file at |path|, creating it if needed and
// truncating any previous content. |description| is a human-readable label
// that is only used in log messages.
// Returns true only if the file could be opened and every byte was written.
bool WriteBytecode(const std::string& path,
                   const std::string& data,
                   const char* description) {
  auto fd = base::OpenFile(path, O_WRONLY | O_TRUNC | O_CREAT, 0644);
  if (!fd) {
    PERFETTO_ELOG("Could not open %s path %s", description, path.c_str());
    return false;
  }
  PERFETTO_LOG("Writing %s (%zu bytes) into %s", description, data.size(),
               path.c_str());
  // A negative result or a short count means the output file is incomplete
  // (e.g. disk full); report it instead of claiming success.
  const auto written = base::WriteAll(*fd, data.data(), data.size());
  if (written < 0 || static_cast<size_t>(written) != data.size()) {
    PERFETTO_ELOG("Failed to write %s into %s", description, path.c_str());
    return false;
  }
  return true;
}

// Writes |data| to the file at |path| as a textproto snippet of the form
//   trace_filter {
//     <field_name>: "\ooo\ooo..."
//   }
// where every input byte is emitted as a backslash followed by exactly three
// octal digits, making the output suitable for .pbtx files.
// |description| is a human-readable label that is only used in log messages.
// Returns true only if the file could be opened and the whole snippet was
// written.
bool WriteBytecodeOctal(const std::string& path,
                        const std::string& data,
                        const char* field_name,
                        const char* description) {
  auto fd = base::OpenFile(path, O_WRONLY | O_TRUNC | O_CREAT, 0644);
  if (!fd) {
    PERFETTO_ELOG("Could not open %s path %s", description, path.c_str());
    return false;
  }
  std::string oct_str;
  // Each byte expands to 4 characters; the extra room covers the wrapper text.
  oct_str.reserve(data.size() * 4 + 64);
  oct_str.append("trace_filter {\n  ");
  oct_str.append(field_name);
  oct_str.append(": \"");
  for (char c : data) {
    // Go through uint8_t so that bytes >= 0x80 are not treated as negative.
    const auto octet = static_cast<uint8_t>(c);
    const char escaped[] = {'\\',
                            static_cast<char>('0' + ((octet >> 6) & 7)),
                            static_cast<char>('0' + ((octet >> 3) & 7)),
                            static_cast<char>('0' + (octet & 7)), '\0'};
    oct_str.append(escaped);
  }
  oct_str.append("\"\n}\n");
  PERFETTO_LOG("Writing %s (%zu bytes) into %s", description, oct_str.size(),
               path.c_str());
  // A negative result or a short count means the output file is incomplete;
  // report it instead of claiming success.
  const auto written = base::WriteAll(*fd, oct_str.data(), oct_str.size());
  if (written < 0 || static_cast<size_t>(written) != oct_str.size()) {
    PERFETTO_ELOG("Failed to write %s into %s", description, path.c_str());
    return false;
  }
  return true;
}

// Identifiers returned by getopt_long() for options that only have a long
// form. They start at 256 so that they can never collide with the single-char
// (ASCII) codes used by the short options.
enum LongOnlyOption {
  kV54Out = 256,
  kV54OctOut = 257,
  kMinBytecodeParser = 258,
};

// Parses version string (v1, v2, v54) to BytecodeVersion enum.
std::optional<protozero::FilterBytecodeGenerator::BytecodeVersion>
ParseBytecodeVersion(const std::string& version_str) {
  using BytecodeVersion = protozero::FilterBytecodeGenerator::BytecodeVersion;
  if (version_str == "v1")
    return BytecodeVersion::kV1;
  if (version_str == "v2")
    return BytecodeVersion::kV2;
  if (version_str == "v54")
    return BytecodeVersion::kV54;
  return std::nullopt;
}

int Main(int argc, char** argv) {
  static const option long_options[] = {
      {"help", no_argument, nullptr, 'h'},
      {"version", no_argument, nullptr, 'v'},
      {"dedupe", no_argument, nullptr, 'd'},
      {"proto_path", required_argument, nullptr, 'I'},
      {"schema_in", required_argument, nullptr, 's'},
      {"root_message", required_argument, nullptr, 'r'},
      {"msg_in", required_argument, nullptr, 'i'},
      {"msg_out", required_argument, nullptr, 'o'},
      {"config_in", required_argument, nullptr, 'c'},
      {"filter_in", required_argument, nullptr, 'f'},
      {"filter_out", required_argument, nullptr, 'F'},
      {"filter_oct_out", required_argument, nullptr, 'T'},
      {"overlay_v54_out", required_argument, nullptr, kV54Out},
      {"overlay_v54_oct_out", required_argument, nullptr, kV54OctOut},
      {"min-bytecode-parser", required_argument, nullptr, kMinBytecodeParser},
      {nullptr, 0, nullptr, 0}};

  std::string msg_in;
  std::string msg_out;
  std::string config_in;
  std::string filter_in;
  std::string schema_in;
  std::string filter_out;
  std::string filter_oct_out;
  std::string overlay_v54_out;
  std::string overlay_v54_oct_out;
  std::string proto_path;
  std::string root_message_arg;
  std::string min_bytecode_parser = "v2";  // Default to v2 for compatibility
  bool dedupe = false;

  for (;;) {
    int option =
        getopt_long(argc, argv, "hvdI:s:r:i:o:f:F:T:c:", long_options, nullptr);

    if (option == -1)
      break;  // EOF.

    if (option == 'v') {
      printf("%s\n", base::GetVersionString());
      exit(0);
    }

    if (option == 'd') {
      dedupe = true;
      continue;
    }

    if (option == 'I') {
      proto_path = optarg;
      continue;
    }

    if (option == 's') {
      schema_in = optarg;
      continue;
    }

    if (option == 'c') {
      config_in = optarg;
      continue;
    }

    if (option == 'r') {
      root_message_arg = optarg;
      continue;
    }

    if (option == 'i') {
      msg_in = optarg;
      continue;
    }

    if (option == 'o') {
      msg_out = optarg;
      continue;
    }

    if (option == 'f') {
      filter_in = optarg;
      continue;
    }

    if (option == 'F') {
      filter_out = optarg;
      continue;
    }

    if (option == 'T') {
      filter_oct_out = optarg;
      continue;
    }

    if (option == kV54Out) {
      overlay_v54_out = optarg;
      continue;
    }

    if (option == kV54OctOut) {
      overlay_v54_oct_out = optarg;
      continue;
    }

    if (option == kMinBytecodeParser) {
      min_bytecode_parser = optarg;
      continue;
    }

    if (option == 'h') {
      fprintf(stdout, kUsage);
      exit(0);
    }

    fprintf(stderr, kUsage);
    exit(1);
  }

  if (msg_in.empty() && filter_in.empty() && schema_in.empty()) {
    fprintf(stderr, kUsage);
    return 1;
  }

  if (!filter_in.empty() && !config_in.empty()) {
    fprintf(stderr, kUsage);
    return 1;
  }

  std::string msg_in_data;
  if (!msg_in.empty()) {
    PERFETTO_LOG("Loading proto-encoded message from %s", msg_in.c_str());
    if (!base::ReadFile(msg_in, &msg_in_data)) {
      PERFETTO_ELOG("Could not open message file %s", msg_in.c_str());
      return 1;
    }
  }

  protozero::FilterUtil filter;
  if (!schema_in.empty()) {
    PERFETTO_LOG("Loading proto schema from %s", schema_in.c_str());
    if (!filter.LoadMessageDefinition(schema_in, root_message_arg,
                                      proto_path)) {
      PERFETTO_ELOG("Failed to parse proto schema from %s", schema_in.c_str());
      return 1;
    }
    if (dedupe)
      filter.Dedupe();
  }

  protozero::MessageFilter msg_filter;
  std::string filter_data;
  std::string filter_data_src;
  std::string overlay_data;  // For v54 overlay bytecode.
  if (!filter_in.empty()) {
    PERFETTO_LOG("Loading filter bytecode from %s", filter_in.c_str());
    if (!base::ReadFile(filter_in, &filter_data)) {
      PERFETTO_ELOG("Could not open filter file %s", filter_in.c_str());
      return 1;
    }
    filter_data_src = filter_in;
  } else if (!config_in.empty()) {
    PERFETTO_LOG("Loading filter bytecode and rules from %s",
                 config_in.c_str());
    std::string config_data;
    if (!base::ReadFile(config_in, &config_data)) {
      PERFETTO_ELOG("Could not open config file %s", config_in.c_str());
      return 1;
    }
    auto res = TraceConfigTxtToPb(config_data, config_in);
    if (!res.ok()) {
      fprintf(stderr, "%s\n", res.status().c_message());
      return 1;
    }

    std::vector<uint8_t>& config_bytes = res.value();
    protos::gen::TraceConfig config;
    config.ParseFromArray(config_bytes.data(), config_bytes.size());

    const auto& trace_filter = config.trace_filter();

    // Load base string filter chain.
    for (const auto& rule : trace_filter.string_filter_chain().rules()) {
      auto opt_policy = ConvertPolicy(rule.policy());
      if (!opt_policy) {
        PERFETTO_ELOG("Unknown string filter policy %d", rule.policy());
        return 1;
      }
      msg_filter.string_filter().AddRule(
          *opt_policy, rule.regex_pattern(), rule.atrace_payload_starts_with(),
          rule.name(), ConvertSemanticTypes(rule));
    }

    // Load v54 string filter chain. Rules with matching names will replace
    // existing rules; others will be appended.
    for (const auto& rule : trace_filter.string_filter_chain_v54().rules()) {
      auto opt_policy = ConvertPolicy(rule.policy());
      if (!opt_policy) {
        PERFETTO_ELOG("Unknown string filter policy %d", rule.policy());
        return 1;
      }
      msg_filter.string_filter().AddRule(
          *opt_policy, rule.regex_pattern(), rule.atrace_payload_starts_with(),
          rule.name(), ConvertSemanticTypes(rule));
    }

    filter_data = trace_filter.bytecode_v2().empty()
                      ? trace_filter.bytecode()
                      : trace_filter.bytecode_v2();
    filter_data_src = config_in;

    if (trace_filter.has_bytecode_overlay_v54()) {
      overlay_data = trace_filter.bytecode_overlay_v54();
    }
  } else if (!schema_in.empty()) {
    PERFETTO_LOG("Generating filter bytecode from %s", schema_in.c_str());

    // Parse the bytecode version.
    auto bytecode_version = ParseBytecodeVersion(min_bytecode_parser);
    if (!bytecode_version.has_value()) {
      PERFETTO_ELOG("Invalid bytecode version: %s (expected v1, v2, or v54)",
                    min_bytecode_parser.c_str());
      return 1;
    }

    auto result = filter.GenerateFilterBytecode(*bytecode_version);
    filter_data = std::move(result.bytecode);
    overlay_data = std::move(result.v54_overlay);
    filter_data_src = schema_in;
  }

  if (!filter_data.empty()) {
    const void* data = filter_data.data();
    const void* overlay_ptr =
        overlay_data.empty() ? nullptr : overlay_data.data();
    if (!msg_filter.LoadFilterBytecode(data, filter_data.size(), overlay_ptr,
                                       overlay_data.size())) {
      PERFETTO_ELOG("Failed to parse filter bytecode from %s",
                    filter_data_src.c_str());
      return 1;
    }
  }

  // Write the filter bytecode in output.
  if (!filter_out.empty()) {
    if (!WriteBytecode(filter_out, filter_data, "filter bytecode"))
      return 1;
  }
  if (!filter_oct_out.empty()) {
    if (!WriteBytecodeOctal(filter_oct_out, filter_data, "bytecode",
                            "filter bytecode"))
      return 1;
  }

  // Write the v54 overlay bytecode in output (if any).
  if (!overlay_v54_out.empty() && !overlay_data.empty()) {
    if (!WriteBytecode(overlay_v54_out, overlay_data, "v54 overlay bytecode"))
      return 1;
  }
  if (!overlay_v54_oct_out.empty() && !overlay_data.empty()) {
    if (!WriteBytecodeOctal(overlay_v54_oct_out, overlay_data,
                            "bytecode_overlay_v54", "v54 overlay bytecode"))
      return 1;
  }

  // Apply the filter to the input message (if any).
  std::vector<uint8_t> msg_filtered_data;
  if (!msg_in.empty()) {
    PERFETTO_LOG("Applying filter %s to proto message %s",
                 filter_data_src.c_str(), msg_in.c_str());
    msg_filter.enable_field_usage_tracking(true);
    auto res = msg_filter.FilterMessage(msg_in_data.data(), msg_in_data.size());
    if (res.error)
      PERFETTO_FATAL("Filtering failed");
    msg_filtered_data.insert(msg_filtered_data.end(), res.data.get(),
                             res.data.get() + res.size);
  }

  // Write out the filtered message.
  if (!msg_out.empty()) {
    PERFETTO_LOG("Writing filtered proto bytes (%zu bytes) into %s",
                 msg_filtered_data.size(), msg_out.c_str());
    auto fd = base::OpenFile(msg_out, O_WRONLY | O_TRUNC | O_CREAT, 0644);
    base::WriteAll(*fd, msg_filtered_data.data(), msg_filtered_data.size());
  }

  if (!msg_in.empty()) {
    const auto& field_usage_map = msg_filter.field_usage();
    for (const auto& it : field_usage_map) {
      const std::string& field_path_varint = it.first;
      int32_t num_occurrences = it.second;
      std::string path_str = filter.LookupField(field_path_varint);
      printf("%-100s %s %d\n", path_str.c_str(),
             num_occurrences < 0 ? "DROP" : "PASS", std::abs(num_occurrences));
    }
  } else if (!schema_in.empty()) {
    filter.PrintAsText(!filter_data.empty() ? std::make_optional(filter_data)
                                            : std::nullopt);
  }

  if ((!filter_out.empty() || !filter_oct_out.empty()) && !dedupe) {
    PERFETTO_ELOG(
        "Warning: looks like you are generating a filter without --dedupe. For "
        "production use cases, --dedupe can make the output bytecode "
        "significantly smaller.");
  }
  return 0;
}

}  // namespace
}  // namespace proto_filter
}  // namespace perfetto

// Process entry point: forwards the command line to the tool implementation
// and uses its result as the process exit status.
int main(int argc, char** argv) {
  const int exit_status = perfetto::proto_filter::Main(argc, argv);
  return exit_status;
}
