1. 早停策略

import torch.nn as nn
import torch.optim as optim
import time
import matplotlib.pyplot as plt
from tqdm import tqdm# Define the MLP model
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(X_train.shape[1], 10)self.relu = nn.ReLU()self.fc2 = nn.Linear(10, 2)  # Binary classificationdef forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return out# Instantiate the model
model = MLP().to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# Training settings
num_epochs = 20000
early_stop_patience = 50  # Epochs to wait for improvement
best_loss = float('inf')
patience_counter = 0
best_epoch = 0
early_stopped = False# Track losses
train_losses = []
test_losses = []
epochs = []# Start training
start_time = time.time()
with tqdm(total=num_epochs, desc="Training Progress", unit="epoch") as pbar:for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(X_train)train_loss = criterion(outputs, y_train)train_loss.backward()optimizer.step()# Evaluate on the test setmodel.eval()with torch.no_grad():outputs_test = model(X_test)test_loss = criterion(outputs_test, y_test)if (epoch + 1) % 200 == 0:train_losses.append(train_loss.item())test_losses.append(test_loss.item())epochs.append(epoch + 1)# Early stopping checkif test_loss.item() < best_loss:  # If current test loss is better than the bestbest_loss = test_loss.item()  # Update best lossbest_epoch = epoch + 1  # Update best epochpatience_counter = 0  # Reset counter# Save the best modeltorch.save(model.state_dict(), 'best_model.pth')else:patience_counter += 1if patience_counter >= early_stop_patience:print(f"Early stopping triggered! No improvement for {early_stop_patience} epochs.")print(f"Best test loss was at epoch {best_epoch} with a loss of {best_loss:.4f}")early_stopped = Truebreak  # Stop the training loop# Update the progress barpbar.set_postfix({'Train Loss': f'{train_loss.item():.4f}', 'Test Loss': f'{test_loss.item():.4f}'})# Update progress bar every 1000 epochsif (epoch + 1) % 1000 == 0:pbar.update(1000)# Ensure progress bar reaches 100%
if pbar.n < num_epochs:pbar.update(num_epochs - pbar.n)time_all = time.time() - start_time  # Calculate total training time
print(f'Training time: {time_all:.2f} seconds')# If early stopping occurred, load the best model
if early_stopped:print(f"Loading best model from epoch {best_epoch} for final evaluation...")model.load_state_dict(torch.load('best_model.pth'))# Continue training for 50 more epochs after loading the best model
num_extra_epochs = 50
for epoch in range(num_extra_epochs):model.train()optimizer.zero_grad()outputs = model(X_train)train_loss = criterion(outputs, y_train)train_loss.backward()optimizer.step()# Evaluate on the test setmodel.eval()with torch.no_grad():outputs_test = model(X_test)test_loss = criterion(outputs_test, y_test)train_losses.append(train_loss.item())test_losses.append(test_loss.item())epochs.append(num_epochs + epoch + 1)# Print progress for the extra epochsprint(f"Epoch {num_epochs + epoch + 1}: Train Loss = {train_loss.item():.4f}, Test Loss = {test_loss.item():.4f}")# Plot the loss curves
plt.figure(figsize=(10, 6))
plt.plot(epochs, train_losses, label='Train Loss')
plt.plot(epochs, test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Test Loss over Epochs')
plt.legend()
plt.grid(True)
plt.show()# Evaluate final accuracy on the test set
model.eval()
with torch.no_grad():outputs = model(X_test)_, predicted = torch.max(outputs, 1)correct = (predicted == y_test).sum().item()accuracy = correct / y_test.size(0)print(f'Test Accuracy: {accuracy * 100:.2f}%')

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/86169.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/86169.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/86169.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

零基础搭建Spring AI本地开发环境指南

Spring AI 是一个 Spring 官方团队主导的开源项目&#xff0c;旨在将生成式人工智能&#xff08;Generative AI&#xff09;能力无缝集成到 Spring 应用程序中。它提供了一个统一的、Spring 风格的抽象层&#xff0c;简化了与各种大型语言模型&#xff08;LLMs&#xff09;、嵌…

windows登录系统配置双因子认证的解决方案

在数字化浪潮席卷全球的今天&#xff0c;安全如同氧气般不可或缺。Verizon《2023年数据泄露调查报告》指出&#xff0c;80%的黑客攻击与登录凭证失窃直接相关。当传统密码防护变得千疮百孔&#xff0c;企业如何在身份验证的战场上赢得主动权&#xff1f;答案就藏在"双保险…

Java数据结构——线性表Ⅱ

一、链式存储结构概述 1. 基本概念&#xff08;逻辑分析&#xff09; 核心思想&#xff1a;用指针将离散的存储单元串联成逻辑上连续的线性表 设计动机&#xff1a;解决顺序表 "预先分配空间" 与 "动态扩展" 的矛盾 关键特性&#xff1a; 结点空间动态…

技术基石:SpreadJS 引擎赋能极致体验

在能源行业数字化转型的浪潮中&#xff0c;青岛国瑞信息技术有限公司始终以技术创新为核心驱动力&#xff0c;不断探索前沿技术在能源领域的深度应用。其推出的 RCV 行列视生产数据应用系统之所以能够在行业内脱颖而出&#xff0c;离不开背后强大的技术基石 ——SpreadJS 引擎。…

Typora - Typora 打字机模式

Typora 打字机模式 1、基本介绍 Typora 打字机模式&#xff08;Typewriter Mode&#xff09;是一种专注于当前写作行的功能 打字机模式会自动将正在编辑的行保持在屏幕中央&#xff0c;让用户更集中注意力&#xff0c;类似于传统打字机的体验 2、开启方式 点击 【视图】 -…

3.0 compose学习:MVVM框架+Hilt注解调用登录接口

文章目录 前言&#xff1a;1、添加依赖1.1 在settings.gradle.kts中添加1.2 在应用级的build.gradle.kts添加插件依赖1.3 在module级的build.gradle.kts添加依赖 2、实体类2.1 request2.2 reponse 3、网络请求3.1 ApiService3.2 NetworkModule3.3 拦截器 添加token3.4 Hilt 的 …

git学习资源

动画演示&#xff1a;Learn Git Branching 终极目标&#xff08;能看懂即入门&#xff09;&#xff1a;git 简明指南 Git 教程 | 菜鸟教程

C++ 第二阶段:模板编程 - 第一节:函数模板与类模板

目录 一、模板编程的核心概念 1.1 什么是模板编程&#xff1f; 二、函数模板详解 2.1 函数模板的定义与使用 2.1.1 基本语法 2.1.2 示例&#xff1a;通用交换函数 2.1.3 类型推导规则 2.2 函数模板的注意事项 2.2.1 普通函数与函数模板的调用规则 2.2.2 隐式类型转换…

Docker 报错“x509: certificate signed by unknown authority”的排查与解决实录

目录 &#x1f527;Docker 报错“x509: certificate signed by unknown authority”的排查与解决实录 &#x1f4cc; 问题背景 &#x1f9ea; 排查过程 步骤 1&#xff1a;确认加速器地址是否可访问 步骤 2&#xff1a;检查 Docker 是否真的使用了镜像加速器 步骤 3&…

达梦以及其他图形化安装没反应或者报错No more handles [gtk_init_check() failed]

本人安装问题和解决步骤如下&#xff0c;仅供参考 执行 DMInstall.bin 报错 按照网上大部分解决方案 export DISPLAY:0.0 xhost 重新执行 DMInstall.bin&#xff0c;无报错也无反应 安装xclock测试也是同样效果&#xff0c;无报错也无反应 最开始猜测可能是连接工具问题&a…

项目节奏不一致时,如何保持全局平衡

项目节奏不一致时&#xff0c;如何保持全局平衡的关键在于&#xff1a;构建跨项目协调机制、合理配置资源、建立共享节奏看板、优先明确战略驱动、引入缓冲与预警机制。其中&#xff0c;构建跨项目协调机制尤为关键&#xff0c;它能将各项目的排期、优先级和风险实时联动&#…

macOS - 安装微软雅黑字体

文章目录 1、下载资源2、安装3、查看字体 app4、卸载字体 macOS 中打开 Windows 传输过来的文件的时候&#xff0c;经常会提示 xxx 字体缺失。下面以安装 微软雅黑字体为例。 1、下载资源 https://github.com/BronyaCat/Win-Fonts-For-Mac 2、安装 双击 Fonts 文件夹下的 msy…

ArkUI-X资源分类与访问

应用开发过程中&#xff0c;经常需要用到颜色、字体、间距、图片等资源&#xff0c;在不同的设备或配置中&#xff0c;这些资源的值可能不同。 应用资源&#xff1a;借助资源文件能力&#xff0c;开发者在应用中自定义资源&#xff0c;自行管理这些资源在不同的设备或配置中的…

11-StarRocks故障诊断FAQ

StarRocks故障诊断FAQ 概述 本文档整理了StarRocks故障诊断过程中常见的问题和解决方案,涵盖了故障排查、日志分析、性能诊断、问题定位等各个方面,帮助用户快速定位和解决StarRocks相关问题。 故障排查FAQ Q1: 如何排查连接故障? A: 连接故障排查方法: 1. 网络连通性…

敏捷项目管理怎么做?4大主流方法论对比及工具适配方案

在传统瀑布式项目管理中&#xff0c;需求定义、设计、开发、测试等环节如同工业流水线般严格线性推进&#xff0c;展现出强大的流程控制能力。不过今天的软件迭代周期已压缩至周级乃至日级&#xff0c;瀑布式管理难以应对需求的快速变化&#xff0c;敏捷式项目管理则以“小步快…

解决YOLO模型从Python迁移到C++时目标漏检问题——跨语言部署中的关键陷阱与解决方案

问题背景 当我们将Python训练的YOLO模型部署到C环境时&#xff0c;常遇到部分目标漏检问题。这通常源于预处理/后处理差异、数据类型隐式转换或模型转换误差。本文通过完整案例解析核心问题并提供可落地的解决方案。 一、常见原因分析 预处理不一致 Python常用OpenCV&#xff…

【2025CCF中国开源大会】开放注册与会议通知(第二轮)

点击蓝字 关注我们 CCF Opensource Development Committee 2025 CCF中国开源大会 由中国计算机学会主办的 2025 CCF中国开源大会&#xff08;CCF ChinaOSC&#xff09;拟于 2025年8月2日-3日 在上海召开。本届大会以“蓄势引领、众行致远”为主题&#xff0c;由上海交通大学校长…

本地聊天室

测试版还没测试过&#xff0c;后面更新不会继续开源&#xff0c;有问题自行修复 开发环境: PHP版本7.2 Swoole扩展 本地服务器环境&#xff08;如XAMPP、MAMP&#xff09; 功能说明: 注册/登录系统&#xff0c;支持本地用户数据存储 ​ 发送文本、图片和语音消息 ​ 实…

golang学习随便记x-调试与杂类(待续)

编译与调试 调试时从终端键盘输入 调试带有需要用户键盘输入的程序时&#xff0c;VSCode报错&#xff1a;Unable to process evaluate: debuggee is running&#xff0c;因为调试器不知道具体是哪个终端输入。需要配置启动文件 .vscode/launch.json 类似如下&#xff08;注意…

MultipartFile、File 和 Mat

1. MultipartFile (来自 Spring Web) 用途&#xff1a; 代表通过 multipart 形式提交&#xff08;通常是 HTTP POST 请求&#xff09;接收到的文件。 它是 Spring Web 中用于处理 Web 客户端文件上传的核心接口。 关键特性&#xff1a; 抽象&#xff1a; 这是一个接口&#xf…