// ============================================================
// Zorro DLL: LibTorch demo using ONLY main() (no run())
// ============================================================
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif
#ifndef NOMINMAX
#define NOMINMAX
#endif
// #include <windows.h> // <-- REMOVED (not needed anymore if you don't call WinAPI here)
// --- 1) Include LibTorch FIRST ---
#include <torch/torch.h>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <memory>
#include <random>
#include <stdexcept>
#include <tuple>
#include <utility>
#include <vector>
// --- 2) Include Zorro AFTER torch, rename Zorro's 'at' to avoid conflict ---
#define at zorro_at
#ifdef LOG
#undef LOG
#endif
#include <zorro.h>
#undef at
// --- 3) Cleanup common macro landmines ---
#ifdef min
#undef min
#endif
#ifdef max
#undef max
#endif
#ifdef ref
#undef ref
#endif
#ifdef swap
#undef swap
#endif
#ifdef abs
#undef abs
#endif
// ---------- Tiny 1D environment ----------
// A 1-D corridor of `n` cells. The agent starts in cell 0 and tries to reach cell n-1.
// The observation is a one-hot float32 vector of length n marking the current cell.
struct LineWorld {
    int n;          // number of cells (and length of the one-hot state)
    int pos = 0;    // current cell index
    int maxSteps;   // an episode ends after this many steps even if the goal is not reached
    int steps = 0;  // steps taken in the current episode

    LineWorld(int n_, int maxSteps_) : n(n_), maxSteps(maxSteps_) {}

    // Moves the agent back to cell 0, clears the step counter, and returns the start state.
    torch::Tensor reset() {
        steps = 0;
        pos = 0;
        return state();
    }

    // One-hot encoding of `pos` as a float32 tensor of length n.
    torch::Tensor state() const {
        torch::Tensor oneHot = torch::zeros({n}, torch::kFloat32);
        oneHot.index_put_({pos}, 1.0f);
        return oneHot;
    }

    // Outcome of a single step().
    struct StepResult {
        torch::Tensor next_state;
        float reward;
        bool done;
    };

    // Action 0 moves left (clamped at cell 0); any other value moves right (clamped at n-1).
    // Reward is 1.0 on reaching the last cell, otherwise -0.01 per step.
    // `done` is set when the goal is reached or the step limit is hit.
    StepResult step(int action) {
        pos = (action == 0) ? std::max(0, pos - 1) : std::min(n - 1, pos + 1);
        ++steps;

        const bool atGoal = (pos == n - 1);
        const bool outOfTime = (steps >= maxSteps);
        return StepResult{ state(), atGoal ? 1.0f : -0.01f, atGoal || outOfTime };
    }
};
// ---------- Simple replay buffer ----------
// One experience tuple (state, action, reward, next state, terminal flag) kept in the replay buffer.
// Field order matters: callers aggregate-initialize it as { s, a, r, ns, done }.
struct Transition {
    torch::Tensor s;    // state in which the action was taken
    int a = 0;          // action index
    float r = 0.0f;     // reward received for the step
    torch::Tensor ns;   // resulting state
    bool done = false;  // true if the episode ended on this step
};
struct ReplayBuffer {
std::vector<Transition> data;
size_t capacity;
size_t idx = 0;
bool filled = false;
ReplayBuffer(size_t cap) : capacity(cap) {
data.resize(capacity);
}
void push(const Transition& t) {
data[idx] = t;
idx = (idx + 1) % capacity;
if (idx == 0) filled = true;
}
size_t size() const {
return filled ? capacity : idx;
}
bool can_sample(size_t batch) const {
return size() >= batch;
}
std::vector<Transition> sample(size_t batch, std::mt19937& rng) const {
std::uniform_int_distribution<size_t> dist(0, size() - 1);
std::vector<Transition> out;
out.reserve(batch);
for (size_t i = 0; i < batch; i++)
out.push_back(data[dist(rng)]);
return out;
}
};
// ---------- Q Network ----------
// Q-value network: stateDim -> 64 -> 64 -> actionDim, ReLU after the two hidden layers.
// The output layer is linear: one unbounded Q-value per action.
struct QNetImpl : torch::nn::Module {
    // Declared empty; each layer is created and registered in the constructor.
    torch::nn::Linear l1{nullptr}, l2{nullptr}, l3{nullptr};

    QNetImpl(int stateDim, int actionDim) {
        constexpr int kHidden = 64;
        l1 = register_module("l1", torch::nn::Linear(stateDim, kHidden));
        l2 = register_module("l2", torch::nn::Linear(kHidden, kHidden));
        l3 = register_module("l3", torch::nn::Linear(kHidden, actionDim));
    }

    torch::Tensor forward(torch::Tensor x) {
        torch::Tensor hidden1 = torch::relu(l1(x));
        torch::Tensor hidden2 = torch::relu(l2(hidden1));
        return l3(hidden2);
    }
};
TORCH_MODULE(QNet);
static void hard_update(QNet& dst, QNet& src)
{
torch::NoGradGuard ng;
auto sp = src->parameters();
auto dp = dst->parameters();
for (size_t i = 0; i < sp.size(); i++)
dp[i].copy_(sp[i]);
}
struct Agent {
QNet q;
QNet target;
torch::optim::Adam opt;
ReplayBuffer rb;
float gamma = 0.99f;
float eps = 1.0f;
float epsMin = 0.05f;
float epsDecay = 0.995f;
Agent(int stateDim, int actionDim, float lr, size_t replayCap)
: q(QNet(stateDim, actionDim)),
target(QNet(stateDim, actionDim)),
opt(q->parameters(), torch::optim::AdamOptions(lr)),
rb(replayCap)
{
hard_update(target, q);
}
int act(const torch::Tensor& s, std::mt19937& rng) {
std::uniform_real_distribution<float> u(0.0f, 1.0f);
if (u(rng) < eps) {
std::uniform_int_distribution<int> aDist(0, 1);
return aDist(rng);
}
torch::NoGradGuard ng;
auto qvals = q->forward(s);
return qvals.argmax().item<int>();
}
void decay_epsilon() {
eps *= epsDecay;
if (eps < epsMin) eps = epsMin;
}
void train_step(size_t batchSize, std::mt19937& rng) {
if (!rb.can_sample(batchSize)) return;
auto batch = rb.sample(batchSize, rng);
std::vector<torch::Tensor> ss, nss;
std::vector<int64_t> aa;
std::vector<float> rr, dd;
for (const auto& t : batch) {
ss.push_back(t.s);
nss.push_back(t.ns);
aa.push_back((int64_t)t.a);
rr.push_back(t.r);
dd.push_back(t.done ? 1.0f : 0.0f);
}
auto S = torch::stack(ss);
auto NS = torch::stack(nss);
auto A = torch::tensor(aa, torch::kInt64);
auto R = torch::tensor(rr, torch::kFloat32);
auto D = torch::tensor(dd, torch::kFloat32);
auto qvals = q->forward(S);
auto q_sa = qvals.gather(1, A.unsqueeze(1)).squeeze(1);
torch::Tensor next_q;
{
torch::NoGradGuard ng;
next_q = std::get<0>(target->forward(NS).max(1));
}
auto y = R + gamma * next_q * (1.0f - D);
auto loss = torch::mse_loss(q_sa, y);
opt.zero_grad();
loss.backward();
opt.step();
}
void update_target() { hard_update(target, q); }
};
// ============================================================
// Zorro calls this exported main() once (no run() used)
// ============================================================
extern "C" DLLFUNC int main()
{
// (Optional: keep thread limits even without DLL path management)
// _putenv_s("OMP_NUM_THREADS", "1");
// _putenv_s("MKL_NUM_THREADS", "1");
// _putenv_s("KMP_DUPLICATE_LIB_OK", "TRUE");
// torch::set_num_threads(1);
// torch::set_num_interop_threads(1);
// Make printing work in Zorro-hosted DLLs
setvbuf(stdout, nullptr, _IONBF, 0);
FILE* f = fopen("Log\\mt6409_torch.txt", "a");
if (f) { fprintf(f, "Started MT6409\n"); fflush(f); }
try {
printf("MT6409 Torch demo starting (main)...\n");
torch::manual_seed(0);
const int stateDim = 9;
const int actionDim = 2;
LineWorld env(stateDim, 30);
std::mt19937 rng(123);
Agent agent1(stateDim, actionDim, 1e-3f, 5000);
Agent agent2(stateDim, actionDim, 1e-3f, 5000);
const int episodes = 200;
const int targetUpdateEvery = 20;
const size_t batchSize = 64;
float avg1 = 0.0f, avg2 = 0.0f;
for (int ep = 1; ep <= episodes; ep++) {
// Agent 1
{
auto s = env.reset();
bool done = false;
float total = 0.0f;
while (!done) {
int a = agent1.act(s, rng);
auto res = env.step(a);
agent1.rb.push({s, a, res.reward, res.next_state, res.done});
agent1.train_step(batchSize, rng);
total += res.reward;
s = res.next_state;
done = res.done;
}
agent1.decay_epsilon();
avg1 = 0.95f * avg1 + 0.05f * total;
}
// Agent 2
{
auto s = env.reset();
bool done = false;
float total = 0.0f;
while (!done) {
int a = agent2.act(s, rng);
auto res = env.step(a);
agent2.rb.push({s, a, res.reward, res.next_state, res.done});
agent2.train_step(batchSize, rng);
total += res.reward;
s = res.next_state;
done = res.done;
}
agent2.decay_epsilon();
avg2 = 0.95f * avg2 + 0.05f * total;
}
if (ep % targetUpdateEvery == 0) {
agent1.update_target();
agent2.update_target();
}
if (ep % 25 == 0) {
printf("Episode %d | Agent1 avgR=%.4f eps=%.4f | Agent2 avgR=%.4f eps=%.4f\n",
ep, avg1, agent1.eps, avg2, agent2.eps);
if (f) {
fprintf(f, "Episode %d | Agent1 avgR=%.4f eps=%.4f | Agent2 avgR=%.4f eps=%.4f\n",
ep, avg1, agent1.eps, avg2, agent2.eps);
fflush(f);
}
}
}
printf("Done.\n");
if (f) { fprintf(f, "Done MT6409\n"); fclose(f); }
return 0;
}
catch (const c10::Error& e) {
printf("TORCH c10::Error: %s\n", e.what());
}
catch (const std::exception& e) {
printf("std::exception: %s\n", e.what());
}
catch (...) {
printf("Unknown exception in main().\n");
}
if (f) fclose(f);
return 1;
}