This source file is a training demonstration that runs inside a Zorro plug-in. Picture a workshop where a learner explores a hallway and learns which direction leads to a goal. The opening section hardens the include environment so that Windows macros and Zorro names do not clash with the LibTorch headers. It includes LibTorch first, then includes the Zorro header while temporarily renaming Zorro's `at` identifier to `zorro_at`, so that it does not collide with LibTorch's `at` tensor namespace.

Next comes LineWorld, a tiny environment. It stores a current position, a step counter, and a limit on how long an episode may run. Reset returns a fresh state. State builds a one-hot tensor so that only the current position is marked as active. Step applies an action, clamps the position within bounds, updates the counter, and returns the next state along with a reward and a done flag.

Experience is recorded in a replay buffer, a circular notebook with fixed capacity. Each entry holds a transition: state, action, reward, next state, and termination. Sampling pulls random memories so training is less correlated and more stable.

The learning brain is a neural network module with three linear layers and ReLU activations. An Agent owns a live network and a target network. The live network is optimized, while the target network is refreshed from time to time. Acting uses an epsilon-greedy rule: sometimes it explores randomly, otherwise it chooses the action with the best predicted value. Training stacks a batch of tensors, gathers the value for each chosen action, computes a bootstrapped target from the target network, measures the error with a mean squared error loss, and updates the weights using Adam.

Finally, the exported main function is what Zorro calls. It sets up logging, seeds randomness, runs training loops for two agents, prints progress, and wraps everything in exception handling so the host process remains safe.

Small note about DLL loading: if compile64.bat does not add the LibTorch lib directory to PATH for the Zorro runtime, Windows cannot find the torch and c10 DLLs, and the plug-in may fail to load properly.

Code
// ============================================================
//  Zorro DLL: LibTorch demo using ONLY main() (no run())
// ============================================================

#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif

#ifndef NOMINMAX
#define NOMINMAX
#endif

// #include <windows.h>   // <-- REMOVED (not needed anymore if you don't call WinAPI here)

// --- 1) Include LibTorch FIRST ---
#include <torch/torch.h>

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <random>
#include <stdexcept>
#include <tuple>
#include <vector>

// --- 2) Include Zorro AFTER torch, rename Zorro's 'at' to avoid conflict ---
#define at zorro_at
#ifdef LOG
#undef LOG
#endif
#include <zorro.h>
#undef at

// --- 3) Cleanup common macro landmines ---
#ifdef min
#undef min
#endif
#ifdef max
#undef max
#endif
#ifdef ref
#undef ref
#endif
#ifdef swap
#undef swap
#endif
#ifdef abs
#undef abs
#endif

// ---------- Tiny 1D environment ----------
struct LineWorld {
  int n;
  int pos = 0;
  int maxSteps;
  int steps = 0;

  LineWorld(int n_, int maxSteps_) : n(n_), maxSteps(maxSteps_) {}

  torch::Tensor reset() {
    pos = 0;
    steps = 0;
    return state();
  }

  torch::Tensor state() const {
    auto s = torch::zeros({n}, torch::kFloat32);
    s.index_put_({pos}, 1.0f);
    return s;
  }

  struct StepResult {
    torch::Tensor next_state;
    float reward;
    bool done;
  };

  StepResult step(int action) {
    if (action == 0) pos = std::max(0, pos - 1);
    else             pos = std::min(n - 1, pos + 1);

    steps++;

    bool reached = (pos == n - 1);
    bool timeout = (steps >= maxSteps);
    bool done = reached || timeout;

    float reward = reached ? 1.0f : -0.01f;
    return { state(), reward, done };
  }
};

// ---------- Simple replay buffer ----------
// One recorded environment interaction, as stored in the replay buffer.
// Remains an aggregate: brace-initialisation in field order
// {s, a, r, ns, done} keeps working. The scalar members carry default
// initialisers so a default-constructed Transition never holds
// indeterminate values.
struct Transition {
  torch::Tensor s;      // state observed before the action
  int a = 0;            // action that was taken
  float r = 0.0f;       // reward received for the action
  torch::Tensor ns;     // state observed after the action
  bool done = false;    // true when the step ended the episode
};

// Fixed-capacity ring buffer of transitions. Once full, each push overwrites
// the oldest entry.
struct ReplayBuffer {
  std::vector<Transition> data;  // pre-sized storage of `capacity` slots
  size_t capacity;               // maximum number of stored transitions
  size_t idx = 0;                // slot the next push() writes to
  bool filled = false;           // becomes true once the write index wraps

  ReplayBuffer(size_t cap) : capacity(cap) {
    data.resize(capacity);
  }

  // Stores `t` in the current slot and advances the write index.
  // A zero-capacity buffer has no slot to write to, so the transition is
  // dropped instead of computing `% 0`.
  void push(const Transition& t) {
    if (capacity == 0) return;
    data[idx] = t;
    idx = (idx + 1) % capacity;
    if (idx == 0) filled = true;
  }

  // Number of valid transitions currently stored.
  size_t size() const {
    return filled ? capacity : idx;
  }

  // True when at least `batch` transitions are available.
  bool can_sample(size_t batch) const {
    return size() >= batch;
  }

  // Draws `batch` transitions uniformly at random, with replacement.
  // Returns an empty vector when the buffer is empty or `batch` is zero;
  // without this guard the distribution bound `size() - 1` would wrap around.
  std::vector<Transition> sample(size_t batch, std::mt19937& rng) const {
    std::vector<Transition> out;
    const size_t stored = size();
    if (stored == 0 || batch == 0) return out;

    std::uniform_int_distribution<size_t> dist(0, stored - 1);
    out.reserve(batch);
    for (size_t i = 0; i < batch; ++i)
      out.push_back(data[dist(rng)]);
    return out;
  }
};

// ---------- Q Network ----------
// Fully connected network: stateDim -> 64 -> 64 -> actionDim, with ReLU after
// each of the first two layers and a linear output layer.
struct QNetImpl : torch::nn::Module {
  torch::nn::Linear l1{nullptr};
  torch::nn::Linear l2{nullptr};
  torch::nn::Linear l3{nullptr};

  // Builds the three layers and registers them under the names "l1".."l3".
  QNetImpl(int stateDim, int actionDim) {
    const int kHiddenUnits = 64;
    l1 = register_module("l1", torch::nn::Linear(stateDim, kHiddenUnits));
    l2 = register_module("l2", torch::nn::Linear(kHiddenUnits, kHiddenUnits));
    l3 = register_module("l3", torch::nn::Linear(kHiddenUnits, actionDim));
  }

  // Runs `x` through the network and returns the raw output of the last layer.
  torch::Tensor forward(torch::Tensor x) {
    const auto firstHidden = torch::relu(l1(x));
    const auto secondHidden = torch::relu(l2(firstHidden));
    return l3(secondHidden);
  }
};
TORCH_MODULE(QNet);

// Overwrites every parameter and buffer of `dst` with the corresponding
// tensor of `src`, without recording the copies for autograd.
//
// Throws std::invalid_argument when the two networks do not expose the same
// number of parameters or buffers; indexing `dst` by `src`'s count would
// otherwise read past the end of the shorter list.
static void hard_update(QNet& dst, QNet& src)
{
  torch::NoGradGuard ng;

  auto sp = src->parameters();
  auto dp = dst->parameters();
  if (sp.size() != dp.size())
    throw std::invalid_argument("hard_update: parameter count mismatch between networks");
  for (size_t i = 0; i < sp.size(); ++i)
    dp[i].copy_(sp[i]);

  // Buffers (non-parameter state) are synchronised as well; for a network
  // built only from Linear layers this loop has nothing to copy.
  auto sb = src->buffers();
  auto db = dst->buffers();
  if (sb.size() != db.size())
    throw std::invalid_argument("hard_update: buffer count mismatch between networks");
  for (size_t i = 0; i < sb.size(); ++i)
    db[i].copy_(sb[i]);
}

struct Agent {
  QNet q;
  QNet target;
  torch::optim::Adam opt;
  ReplayBuffer rb;

  float gamma = 0.99f;
  float eps = 1.0f;
  float epsMin = 0.05f;
  float epsDecay = 0.995f;

  Agent(int stateDim, int actionDim, float lr, size_t replayCap)
    : q(QNet(stateDim, actionDim)),
      target(QNet(stateDim, actionDim)),
      opt(q->parameters(), torch::optim::AdamOptions(lr)),
      rb(replayCap)
  {
    hard_update(target, q);
  }

  int act(const torch::Tensor& s, std::mt19937& rng) {
    std::uniform_real_distribution<float> u(0.0f, 1.0f);
    if (u(rng) < eps) {
      std::uniform_int_distribution<int> aDist(0, 1);
      return aDist(rng);
    }
    torch::NoGradGuard ng;
    auto qvals = q->forward(s);
    return qvals.argmax().item<int>();
  }

  void decay_epsilon() {
    eps *= epsDecay;
    if (eps < epsMin) eps = epsMin;
  }

  void train_step(size_t batchSize, std::mt19937& rng) {
    if (!rb.can_sample(batchSize)) return;

    auto batch = rb.sample(batchSize, rng);

    std::vector<torch::Tensor> ss, nss;
    std::vector<int64_t> aa;
    std::vector<float> rr, dd;

    for (const auto& t : batch) {
      ss.push_back(t.s);
      nss.push_back(t.ns);
      aa.push_back((int64_t)t.a);
      rr.push_back(t.r);
      dd.push_back(t.done ? 1.0f : 0.0f);
    }

    auto S  = torch::stack(ss);
    auto NS = torch::stack(nss);
    auto A  = torch::tensor(aa, torch::kInt64);
    auto R  = torch::tensor(rr, torch::kFloat32);
    auto D  = torch::tensor(dd, torch::kFloat32);

    auto qvals = q->forward(S);
    auto q_sa  = qvals.gather(1, A.unsqueeze(1)).squeeze(1);

    torch::Tensor next_q;
    {
      torch::NoGradGuard ng;
      next_q = std::get<0>(target->forward(NS).max(1));
    }
    auto y = R + gamma * next_q * (1.0f - D);

    auto loss = torch::mse_loss(q_sa, y);

    opt.zero_grad();
    loss.backward();
    opt.step();
  }

  void update_target() { hard_update(target, q); }
};

// ============================================================
//  Zorro calls this exported main() once (no run() used)
// ============================================================
extern "C" DLLFUNC int main()
{
  // (Optional: keep thread limits even without DLL path management)
  // _putenv_s("OMP_NUM_THREADS", "1");
  // _putenv_s("MKL_NUM_THREADS", "1");
  // _putenv_s("KMP_DUPLICATE_LIB_OK", "TRUE");
  // torch::set_num_threads(1);
  // torch::set_num_interop_threads(1);

  // Make printing work in Zorro-hosted DLLs
  setvbuf(stdout, nullptr, _IONBF, 0);

  FILE* f = fopen("Log\\mt6409_torch.txt", "a");
  if (f) { fprintf(f, "Started MT6409\n"); fflush(f); }

  try {
    printf("MT6409 Torch demo starting (main)...\n");

    torch::manual_seed(0);

    const int stateDim = 9;
    const int actionDim = 2;

    LineWorld env(stateDim, 30);
    std::mt19937 rng(123);

    Agent agent1(stateDim, actionDim, 1e-3f, 5000);
    Agent agent2(stateDim, actionDim, 1e-3f, 5000);

    const int episodes = 200;
    const int targetUpdateEvery = 20;
    const size_t batchSize = 64;

    float avg1 = 0.0f, avg2 = 0.0f;

    for (int ep = 1; ep <= episodes; ep++) {

      // Agent 1
      {
        auto s = env.reset();
        bool done = false;
        float total = 0.0f;

        while (!done) {
          int a = agent1.act(s, rng);
          auto res = env.step(a);
          agent1.rb.push({s, a, res.reward, res.next_state, res.done});
          agent1.train_step(batchSize, rng);

          total += res.reward;
          s = res.next_state;
          done = res.done;
        }

        agent1.decay_epsilon();
        avg1 = 0.95f * avg1 + 0.05f * total;
      }

      // Agent 2
      {
        auto s = env.reset();
        bool done = false;
        float total = 0.0f;

        while (!done) {
          int a = agent2.act(s, rng);
          auto res = env.step(a);
          agent2.rb.push({s, a, res.reward, res.next_state, res.done});
          agent2.train_step(batchSize, rng);

          total += res.reward;
          s = res.next_state;
          done = res.done;
        }

        agent2.decay_epsilon();
        avg2 = 0.95f * avg2 + 0.05f * total;
      }

      if (ep % targetUpdateEvery == 0) {
        agent1.update_target();
        agent2.update_target();
      }

      if (ep % 25 == 0) {
        printf("Episode %d | Agent1 avgR=%.4f eps=%.4f | Agent2 avgR=%.4f eps=%.4f\n",
               ep, avg1, agent1.eps, avg2, agent2.eps);

        if (f) {
          fprintf(f, "Episode %d | Agent1 avgR=%.4f eps=%.4f | Agent2 avgR=%.4f eps=%.4f\n",
                  ep, avg1, agent1.eps, avg2, agent2.eps);
          fflush(f);
        }
      }
    }

    printf("Done.\n");
    if (f) { fprintf(f, "Done MT6409\n"); fclose(f); }

    return 0;
  }
  catch (const c10::Error& e) {
    printf("TORCH c10::Error: %s\n", e.what());
  }
  catch (const std::exception& e) {
    printf("std::exception: %s\n", e.what());
  }
  catch (...) {
    printf("Unknown exception in main().\n");
  }

  if (f) fclose(f);
  return 1;
}

Last edited by TipmyPip; 2 hours ago.